/
openapi.py
302 lines (256 loc) · 11.3 KB
/
openapi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""Utilities for generating OpenAPI Specification (fka Swagger) entities from
marshmallow :class:`Schemas <marshmallow.Schema>` and :class:`Fields <marshmallow.fields.Field>`.

.. warning::

    This module is treated as private API.
    Users should not need to use this module directly.
"""
from __future__ import annotations
import typing
from packaging.version import Version
import marshmallow
from marshmallow.utils import is_collection
from apispec import APISpec
from apispec.exceptions import APISpecError
from .field_converter import FieldConverterMixin
from .common import (
get_fields,
make_schema_key,
resolve_schema_instance,
get_unique_schema_name,
)
# Translation table from request-location names to the value written to a
# parameter's "in" key. ``schema2parameters`` looks locations up with
# ``.get(location, location)``, so a name that is not listed here is passed
# through unchanged.
__location_map__ = {
    "match_info": "path",
    # Two spellings map to the same "query" location.
    "query": "query",
    "querystring": "query",
    "json": "body",
    "headers": "header",
    "cookies": "cookie",
    # Both form fields and file uploads map to "formData".
    "form": "formData",
    "files": "formData",
}
class OpenAPIConverter(FieldConverterMixin):
    """Adds methods for generating OpenAPI specification from marshmallow schemas and fields.

    :param Version|str openapi_version: The OpenAPI version to use.
        Should be in the form '2.x' or '3.x.x' to comply with the OpenAPI standard.
    :param callable schema_name_resolver: Callable to generate the schema definition name.
        Receives the `Schema` class and returns the name to be used in refs within
        the generated spec. When working with circular referencing this function
        must not return `None` for schemas in a circular reference chain.
    :param APISpec spec: An initialized spec. Nested schemas will be added to the
        spec
    """

    def __init__(
        self,
        openapi_version: Version | str,
        schema_name_resolver,
        spec: APISpec,
    ) -> None:
        # Accept either a ready-made Version or its string form.
        self.openapi_version = (
            Version(openapi_version)
            if isinstance(openapi_version, str)
            else openapi_version
        )
        self.schema_name_resolver = schema_name_resolver
        self.spec = spec
        # ``init_attribute_functions`` is inherited (not defined in this class).
        self.init_attribute_functions()
        self.init_parameter_attribute_functions()
        # Schema references: schema key -> name, as read by
        # ``resolve_nested_schema`` and ``get_ref_dict``. Nothing in this
        # class adds entries to it.
        self.refs: dict = {}

    def init_parameter_attribute_functions(self) -> None:
        """Reset the list of parameter attribute functions to the built-in ones."""
        self.parameter_attribute_functions = [
            self.field2required,
            self.list2param,
        ]

    def add_parameter_attribute_function(self, func) -> None:
        """Method to add a field parameter function to the list of field
        parameter functions that will be called on a field to convert it to a
        field parameter.

        :param func func: the field parameter function to add
            The attribute function will be bound to the
            `OpenAPIConverter <apispec.ext.marshmallow.openapi.OpenAPIConverter>`
            instance.
            It will be called for each field in a schema with
            `self <apispec.ext.marshmallow.openapi.OpenAPIConverter>` and a
            `field <marshmallow.fields.Field>` instance as
            positional arguments and `ret <dict>` keyword argument.
            May mutate `ret`.
            User added field parameter functions will be called after all built-in
            field parameter functions in the order they were added.
        """
        # Bind the plain function to this instance via the descriptor protocol
        # so it receives ``self`` like a regular method.
        bound_func = func.__get__(self)
        setattr(self, func.__name__, bound_func)
        self.parameter_attribute_functions.append(bound_func)

    def resolve_nested_schema(self, schema):
        """Return the OpenAPI representation of a marshmallow Schema.

        Adds the schema to the spec if it isn't already present.

        Typically will return a dictionary with the reference to the schema's
        path in the spec unless the `schema_name_resolver` returns `None`, in
        which case the returned dictionary will contain a JSON Schema Object
        representation of the schema.

        :param schema: schema to add to the spec
        :raise APISpecError: if the name resolver returns a falsy name and
            inlining the schema raises a `RuntimeError`.
        """
        try:
            schema_instance = resolve_schema_instance(schema)
        # If schema is a string and is not found in registry,
        # assume it is a schema reference
        except marshmallow.exceptions.RegistryError:
            return schema

        schema_key = make_schema_key(schema_instance)
        if schema_key not in self.refs:
            name = self.schema_name_resolver(schema)
            if not name:
                # No name: inline the schema instead of registering a ref.
                try:
                    json_schema = self.schema2jsonschema(schema_instance)
                except RuntimeError as exc:
                    # RecursionError is a subclass of RuntimeError; presumably
                    # that is what unbounded inlining of circular schemas raises.
                    raise APISpecError(
                        f"Name resolver returned None for schema {schema} which is "
                        "part of a chain of circular referencing schemas. Please"
                        " ensure that the schema_name_resolver passed to"
                        " MarshmallowPlugin returns a string for all circular"
                        " referencing schemas."
                    ) from exc
                if getattr(schema, "many", False):
                    return {"type": "array", "items": json_schema}
                return json_schema
            name = get_unique_schema_name(self.spec.components, name)
            self.spec.components.schema(name, schema=schema)
        return self.get_ref_dict(schema_instance)

    def schema2parameters(
        self,
        schema,
        *,
        location,
        name: str = "body",
        required: bool = False,
        description: str | None = None,
    ) -> list[dict]:
        """Return an array of OpenAPI parameters for a given marshmallow
        :class:`Schema <marshmallow.Schema>`. If `location` is "body", then return an array
        of a single parameter; else return an array of a parameter for each included field in
        the :class:`Schema <marshmallow.Schema>`.

        In OpenAPI 3, only "query", "header", "path" or "cookie" are allowed for the location
        of parameters. "requestBody" is used when fields are in the body.

        https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject

        :raise AssertionError: if `schema` has a truthy ``many`` attribute and
            `location` does not resolve to "body".
        """
        location = __location_map__.get(location, location)
        # OAS 2 body parameter
        if location == "body":
            param = {
                "in": location,
                "required": required,
                "name": name,
                "schema": self.resolve_nested_schema(schema),
            }
            if description:
                param["description"] = description
            return [param]

        # Raised explicitly rather than via an ``assert`` statement so the
        # check is not removed when Python runs with ``-O``; the exception
        # type and message are unchanged.
        if getattr(schema, "many", False):
            raise AssertionError(
                "Schemas with many=True are only supported for 'json' location (aka 'in: body')"
            )

        fields = get_fields(schema, exclude_dump_only=True)

        return [
            self._field2parameter(
                field_obj,
                name=field_obj.data_key or field_name,
                location=location,
            )
            for field_name, field_obj in fields.items()
        ]

    def _field2parameter(
        self, field: marshmallow.fields.Field, *, name: str, location: str
    ) -> dict:
        """Return an OpenAPI parameter as a `dict`, given a marshmallow
        :class:`Field <marshmallow.Field>`.

        https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject
        """
        ret: dict = {"in": location, "name": name}

        prop = self.field2property(field)
        if self.openapi_version.major < 3:
            # OAS 2: property attributes sit directly on the parameter.
            ret.update(prop)
        else:
            # OAS 3: description/deprecated are lifted onto the parameter and
            # the remaining attributes are nested under "schema".
            if "description" in prop:
                ret["description"] = prop.pop("description")
            if "deprecated" in prop:
                ret["deprecated"] = prop.pop("deprecated")
            ret["schema"] = prop

        for param_attr_func in self.parameter_attribute_functions:
            ret.update(param_attr_func(field, ret=ret))

        return ret

    def field2required(
        self, field: marshmallow.fields.Field, **kwargs: typing.Any
    ) -> dict:
        """Return the dictionary of OpenAPI parameter attributes for a required field.

        :param Field field: A marshmallow field.
        :rtype: dict
        """
        partial = getattr(field.parent, "partial", False)
        # A truthy non-collection ``partial`` un-requires every field; a
        # collection un-requires only the field names it contains.
        required = field.required and (
            not partial
            or (is_collection(partial) and field.name not in partial)  # type:ignore
        )
        return {"required": required}

    def list2param(self, field: marshmallow.fields.Field, **kwargs: typing.Any) -> dict:
        """Return a dictionary of parameter properties from
        :class:`List <marshmallow.fields.List>` fields.

        :param Field field: A marshmallow field.
        :rtype: dict
        """
        ret: dict = {}
        if isinstance(field, marshmallow.fields.List):
            if self.openapi_version.major < 3:
                ret["collectionFormat"] = "multi"
            else:
                # NOTE(review): OAS 3 documents ``style: form`` for query and
                # cookie parameters only; this is applied regardless of the
                # parameter location. Confirm whether that is intended.
                ret["explode"] = True
                ret["style"] = "form"
        return ret

    def schema2jsonschema(self, schema) -> dict:
        """Return the JSON Schema Object for a given marshmallow
        :class:`Schema <marshmallow.Schema>` instance. Schema may optionally
        provide the ``title`` and ``description`` class Meta options.

        https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schemaObject

        :param Schema schema: A marshmallow Schema instance
        :rtype: dict, a JSON Schema Object
        """
        fields = get_fields(schema)
        Meta = getattr(schema, "Meta", None)
        partial = getattr(schema, "partial", None)

        jsonschema = self.fields2jsonschema(fields, partial=partial)

        if hasattr(Meta, "title"):
            jsonschema["title"] = Meta.title
        if hasattr(Meta, "description"):
            jsonschema["description"] = Meta.description
        # ``additionalProperties`` is emitted only when Meta.unknown is set to
        # something other than EXCLUDE: True for INCLUDE, False otherwise.
        if hasattr(Meta, "unknown") and Meta.unknown != marshmallow.EXCLUDE:
            jsonschema["additionalProperties"] = Meta.unknown == marshmallow.INCLUDE

        return jsonschema

    def fields2jsonschema(self, fields, *, partial=None) -> dict:
        """Return the JSON Schema Object given a mapping between field names and
        :class:`Field <marshmallow.Field>` objects.

        :param dict fields: A dictionary of field name field object pairs
        :param bool|tuple partial: Whether to override a field's required flag.
            If `True` no fields will be set as required. If an iterable, fields
            in the iterable will not be marked as required.
        :rtype: dict, a JSON Schema Object
        """
        jsonschema = {"type": "object", "properties": {}}

        for field_name, field_obj in fields.items():
            # The serialized key (``data_key``) takes precedence over the
            # attribute name in the generated schema.
            observed_field_name = field_obj.data_key or field_name
            prop = self.field2property(field_obj)
            jsonschema["properties"][observed_field_name] = prop

            if field_obj.required:
                # ``partial`` is matched against the attribute name, not the
                # observed (data_key) name.
                if not partial or (
                    is_collection(partial) and field_name not in partial
                ):
                    jsonschema.setdefault("required", []).append(observed_field_name)

        # Sort so the "required" list does not depend on field order.
        if "required" in jsonschema:
            jsonschema["required"].sort()

        return jsonschema

    def get_ref_dict(self, schema):
        """Method to create a dictionary containing a JSON reference to the
        schema in the spec

        If `schema` has a truthy ``many`` attribute, the reference is wrapped
        in an array definition.

        :param schema: schema whose key must already be present in ``self.refs``
        """
        schema_key = make_schema_key(schema)
        ref_schema = self.spec.components.get_ref("schema", self.refs[schema_key])
        if getattr(schema, "many", False):
            return {"type": "array", "items": ref_schema}
        return ref_schema