Skip to content

Commit

Permalink
Add Blueprint.alt_response
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Oct 12, 2020
1 parent e690092 commit d6c4683
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 8 deletions.
75 changes: 69 additions & 6 deletions flask_smorest/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def response(
resp_doc['examples'] = examples
if headers is not None:
resp_doc['headers'] = headers
doc = {'responses': {code: resp_doc}}

def decorator(func):

Expand Down Expand Up @@ -102,13 +101,15 @@ def wrapper(*args, **kwargs):

return resp

# Document default error response
doc['responses']['default'] = 'DEFAULT_ERROR'

# Store doc in wrapper function
# The deepcopy avoids modifying the wrapped function doc
wrapper._apidoc = deepcopy(getattr(wrapper, '_apidoc', {}))
wrapper._apidoc['response'] = doc
# Document default error response
wrapper._apidoc.setdefault(
'response', {}
).setdefault('responses', {})[code] = resp_doc
wrapper._apidoc[
'response']['responses']['default'] = 'DEFAULT_ERROR'
# Indicate which code is the success status code
# Helps other decorators documenting success response
wrapper._apidoc['success_status_code'] = code
Expand All @@ -117,9 +118,66 @@ def wrapper(*args, **kwargs):

return decorator

def alt_response(
self, code, schema_or_ref, *, description=None,
example=None, examples=None, headers=None
):
"""Decorator documenting an alternative response
:param int|str|HTTPStatus code: HTTP status code.
:param schema_or_ref: Either a :class:`Schema <marshmallow.Schema>`
class or instance or a string error reference.
When passing a reference, arguments below are ignored.
:param str description: Description of the response (default: None).
:param dict example: Example of response message.
:param list examples: Examples of response message.
:param dict headers: Headers returned by the response.
"""
# If a ref is passed
if isinstance(schema_or_ref, str):
resp_doc = schema_or_ref
# If a schema is passed
else:
schema = schema_or_ref
if isinstance(schema, type):
schema = schema()

# Document response (schema, description,...) in the API doc
resp_doc = {}
doc_schema = self._make_doc_alternate_response_schema(schema)
if doc_schema is not None:
resp_doc['schema'] = doc_schema
if description is not None:
resp_doc['description'] = description
else:
resp_doc['description'] = http.HTTPStatus(int(code)).phrase
if example is not None:
resp_doc['example'] = example
if examples is not None:
resp_doc['examples'] = examples
if headers is not None:
resp_doc['headers'] = headers

def decorator(func):

@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

# Store doc in wrapper function
# The deepcopy avoids modifying the wrapped function doc
wrapper._apidoc = deepcopy(getattr(wrapper, '_apidoc', {}))
wrapper._apidoc.setdefault(
'response', {}
).setdefault('responses', {})[code] = resp_doc

return wrapper

return decorator

@staticmethod
def _make_doc_response_schema(schema):
"""Override this to modify schema in docs
"""Override this to modify response schema in docs
This can be used to document a wrapping structure.
Expand All @@ -137,6 +195,11 @@ def _make_doc_response_schema(schema):
"""
return schema

@staticmethod
def _make_doc_alternate_response_schema(schema):
"""Override this to modify alternate response schema in docs"""
return schema

@staticmethod
def _prepare_response_content(data):
"""Override this to modify the data structure
Expand Down
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ class Meta:
arg1 = ma.fields.String()
arg2 = ma.fields.Integer()

class ClientErrorSchema(ma.Schema):
error_id = ma.fields.Str()
text = ma.fields.Str()

return namedtuple(
'Model', ('DocSchema', 'DocEtagSchema', 'QueryArgsSchema'))(
DocSchema, DocEtagSchema, QueryArgsSchema)
'Model',
('DocSchema', 'DocEtagSchema', 'QueryArgsSchema', 'ClientErrorSchema')
)(DocSchema, DocEtagSchema, QueryArgsSchema, ClientErrorSchema)
166 changes: 166 additions & 0 deletions tests/test_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,172 @@ def func():
api.spec, 'response', 'DEFAULT_ERROR'
)

@pytest.mark.parametrize('openapi_version', ['2.0', '3.0.2'])
def test_blueprint_alt_response_schema(
self, app, openapi_version, schemas
):
"""Check alternate response schema is correctly documented"""
app.config['OPENAPI_VERSION'] = openapi_version
api = Api(app)
blp = Blueprint('test', 'test', url_prefix='/test')

example = {'error_id': 'E1', 'text': 'client error'}
examples = {
'example 1': {'error_id': 'E1', 'text': 'client error 1'},
'example 2': {'error_id': 'E2', 'text': 'client error 2'},
}
headers = {'X-Custom-Header': 'Header value'}

@blp.route('/')
@blp.alt_response(400, schemas.ClientErrorSchema)
def func():
pass

@blp.route('/description')
@blp.alt_response(
400, schemas.ClientErrorSchema, description='Client error'
)
def func_with_description():
pass

@blp.route('/example')
@blp.alt_response(400, schemas.ClientErrorSchema, example=example)
def func_with_example():
pass

if openapi_version == '3.0.2':
@blp.route('/examples')
@blp.alt_response(
400, schemas.ClientErrorSchema, examples=examples
)
def func_with_examples():
pass

@blp.route('/headers')
@blp.alt_response(400, schemas.ClientErrorSchema, headers=headers)
def func_with_headers():
pass

api.register_blueprint(blp)

paths = api.spec.to_dict()['paths']

schema_ref = build_ref(api.spec, 'schema', 'ClientError')

response = paths['/test/']['get']['responses']['400']
if openapi_version == '2.0':
assert response['schema'] == schema_ref
else:
assert (
response['content']['application/json']['schema'] ==
schema_ref
)
assert response['description'] == http.HTTPStatus(400).phrase

response = paths['/test/description']['get']['responses']['400']
assert response['description'] == 'Client error'

response = paths['/test/example']['get']['responses']['400']
if openapi_version == '2.0':
assert response['examples']['application/json'] == example
else:
assert (
response['content']['application/json']['example'] == example
)

if openapi_version == '3.0.2':
response = paths['/test/examples']['get']['responses']['400']
assert (
response['content']['application/json']['examples'] == examples
)

response = paths['/test/headers']['get']['responses']['400']
assert response['headers'] == headers

@pytest.mark.parametrize('openapi_version', ['2.0', '3.0.2'])
def test_blueprint_alt_response_ref(self, app, openapi_version):
"""Check alternate response passed as reference"""
app.config['OPENAPI_VERSION'] = openapi_version
api = Api(app)
api.spec.components.response('ClientErrorResponse')

blp = Blueprint('test', 'test', url_prefix='/test')

@blp.route('/')
@blp.alt_response(400, "ClientErrorResponse")
def func():
pass

api.register_blueprint(blp)

paths = api.spec.to_dict()['paths']

response_ref = build_ref(api.spec, 'response', 'ClientErrorResponse')

assert paths['/test/']['get']['responses']['400'] == response_ref

@pytest.mark.parametrize('openapi_version', ['2.0', '3.0.2'])
def test_blueprint_multiple_alt_response(
self, app, openapi_version, schemas
):
"""Check multiple nested calls to alt_response"""
app.config['OPENAPI_VERSION'] = openapi_version
api = Api(app)
blp = Blueprint('test', 'test', url_prefix='/test')

@blp.route('/')
@blp.alt_response(400, schemas.ClientErrorSchema)
@blp.alt_response(404, 'NotFoundErrorResponse')
def func():
pass

api.register_blueprint(blp)

paths = api.spec.to_dict()['paths']

schema_ref = build_ref(api.spec, 'schema', 'ClientError')
response_ref = build_ref(api.spec, 'response', 'NotFoundErrorResponse')

response = paths['/test/']['get']['responses']['400']
if openapi_version == '2.0':
assert response['schema'] == schema_ref
else:
assert (
response['content']['application/json']['schema'] ==
schema_ref
)

assert paths['/test/']['get']['responses']['404'] == response_ref

@pytest.mark.parametrize('openapi_version', ['2.0', '3.0.2'])
def test_blueprint_alt_response_wrapper(
self, app, schemas, openapi_version
):
"""Check alt_responses passes response transparently"""
app.config['OPENAPI_VERSION'] = openapi_version
api = Api(app)
api.spec.components.response('ClientErrorResponse')

blp = Blueprint('test', 'test', url_prefix='/test')
client = app.test_client()

@blp.route('/')
@blp.response(200, schemas.DocSchema)
@blp.alt_response(400, "ClientErrorResponse")
def func():
return {'item_id': 12}

api.register_blueprint(blp)

paths = api.spec.to_dict()['paths']

response_ref = build_ref(api.spec, 'response', 'ClientErrorResponse')

assert paths['/test/']['get']['responses']['400'] == response_ref

resp = client.get('test/')
assert resp.json == {'item_id': 12}

@pytest.mark.parametrize('openapi_version', ('2.0', '3.0.2'))
def test_blueprint_pagination(self, app, schemas, openapi_version):
app.config['OPENAPI_VERSION'] = openapi_version
Expand Down

0 comments on commit d6c4683

Please sign in to comment.