Skip to content

Commit

Permalink
Add Blueprint.error
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Oct 12, 2020
1 parent e06325d commit 9df5ce0
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 2 deletions.
61 changes: 61 additions & 0 deletions flask_smorest/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,62 @@ def wrapper(*args, **kwargs):

return decorator

def error(
self, code, schema_or_ref, *, description=None,
example=None, examples=None, headers=None
):
"""Decorator documenting an endpoint error response
:param int|str|HTTPStatus code: HTTP status code.
:param schema_or_ref: Either a :class:`Schema <marshmallow.Schema>`
class or instance or a string error reference.
When passing a reference, arguments below are ignored.
:param str description: Description of the response (default: None).
:param dict example: Example of response message.
:param list examples: Examples of response message.
:param dict headers: Headers returned by the response.
"""
# If a ref is passed
if isinstance(schema_or_ref, str):
doc = {'responses': {code: schema_or_ref}}
# If a schema is passed
else:
schema = schema_or_ref
if isinstance(schema, type):
schema = schema()

# Document response (schema, description,...) in the API doc
resp_doc = {}
doc_schema = self._make_doc_error_response_schema(schema)
if doc_schema is not None:
resp_doc['schema'] = doc_schema
if description is not None:
resp_doc['description'] = description
else:
resp_doc['description'] = http.HTTPStatus(int(code)).phrase
if example is not None:
resp_doc['example'] = example
if examples is not None:
resp_doc['examples'] = examples
if headers is not None:
resp_doc['headers'] = headers
doc = {'responses': {code: resp_doc}}

def decorator(func):

@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

# Store doc in wrapper function
# The deepcopy avoids modifying the wrapped function doc
wrapper._apidoc = deepcopy(getattr(wrapper, '_apidoc', {}))
wrapper._apidoc['response'] = doc

return wrapper

return decorator

@staticmethod
def _make_doc_response_schema(schema):
"""Override this to modify schema in docs
Expand All @@ -137,6 +193,11 @@ def _make_doc_response_schema(schema):
"""
return schema

@staticmethod
def _make_doc_error_response_schema(schema):
"""Override this to modify error schema in docs"""
return schema

@staticmethod
def _prepare_response_content(data):
"""Override this to modify the data structure
Expand Down
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ class Meta:
arg1 = ma.fields.String()
arg2 = ma.fields.Integer()

class ClientErrorSchema(ma.Schema):
error_id = ma.fields.Str()
text = ma.fields.Str()

return namedtuple(
'Model', ('DocSchema', 'DocEtagSchema', 'QueryArgsSchema'))(
DocSchema, DocEtagSchema, QueryArgsSchema)
'Model',
('DocSchema', 'DocEtagSchema', 'QueryArgsSchema', 'ClientErrorSchema')
)(DocSchema, DocEtagSchema, QueryArgsSchema, ClientErrorSchema)
90 changes: 90 additions & 0 deletions tests/test_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,96 @@ def func():
api.spec, 'response', 'DEFAULT_ERROR'
)

@pytest.mark.parametrize('openapi_version', ['2.0', '3.0.2'])
def test_blueprint_error_schema(self, app, openapi_version, schemas):
"""Check error schema is correctly documented"""
app.config['OPENAPI_VERSION'] = openapi_version
api = Api(app)
blp = Blueprint('test', 'test', url_prefix='/test')

example = {'error_id': 'E1', 'text': 'client error'}
examples = {
'example 1': {'error_id': 'E1', 'text': 'client error 1'},
'example 2': {'error_id': 'E2', 'text': 'client error 2'},
}

@blp.route('/')
@blp.error(400, schemas.ClientErrorSchema)
def func():
pass

@blp.route('/description')
@blp.error(400, schemas.ClientErrorSchema, description='Client error')
def func_with_description():
pass

@blp.route('/example')
@blp.error(400, schemas.ClientErrorSchema, example=example)
def func_with_example():
pass

if openapi_version == '3.0.2':
@blp.route('/examples')
@blp.error(400, schemas.ClientErrorSchema, examples=examples)
def func_with_examples():
pass

api.register_blueprint(blp)

paths = api.spec.to_dict()['paths']

schema_ref = build_ref(api.spec, 'schema', 'ClientError')

response = paths['/test/']['get']['responses']['400']
if openapi_version == '2.0':
assert response['schema'] == schema_ref
else:
assert (
response['content']['application/json']['schema'] ==
schema_ref
)
assert response['description'] == http.HTTPStatus(400).phrase

response = paths['/test/description']['get']['responses']['400']
assert response['description'] == 'Client error'

response = paths['/test/example']['get']['responses']['400']
if openapi_version == '2.0':
assert response['examples']['application/json'] == example
else:
assert (
response['content']['application/json']['example'] == example
)

if openapi_version == '3.0.2':
response = paths['/test/examples']['get']['responses']['400']
assert (
response['content']['application/json']['examples'] == examples
)

@pytest.mark.parametrize('openapi_version', ['2.0', '3.0.2'])
def test_blueprint_error_ref(self, app, openapi_version):
"""Check error passed as reference"""
app.config['OPENAPI_VERSION'] = openapi_version
api = Api(app)
api.spec.components.response('ClientErrorResponse')

blp = Blueprint('test', 'test', url_prefix='/test')

@blp.route('/')
@blp.error(400, "ClientErrorResponse")
def func():
pass

api.register_blueprint(blp)

paths = api.spec.to_dict()['paths']

response_ref = build_ref(api.spec, 'response', 'ClientErrorResponse')

response = paths['/test/']['get']['responses']['400']
assert response == response_ref

@pytest.mark.parametrize('openapi_version', ('2.0', '3.0.2'))
def test_blueprint_pagination(self, app, schemas, openapi_version):
app.config['OPENAPI_VERSION'] = openapi_version
Expand Down

0 comments on commit 9df5ce0

Please sign in to comment.