From d6c4683fd925675d646a4354e396ab8c459bc2af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Tue, 9 Jun 2020 00:13:39 +0200 Subject: [PATCH] Add Blueprint.alt_response --- flask_smorest/response.py | 75 +++++++++++++++-- tests/conftest.py | 9 ++- tests/test_blueprint.py | 166 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 242 insertions(+), 8 deletions(-) diff --git a/flask_smorest/response.py b/flask_smorest/response.py index 2845d28c..13edd6e0 100644 --- a/flask_smorest/response.py +++ b/flask_smorest/response.py @@ -66,7 +66,6 @@ def response( resp_doc['examples'] = examples if headers is not None: resp_doc['headers'] = headers - doc = {'responses': {code: resp_doc}} def decorator(func): @@ -102,13 +101,15 @@ def wrapper(*args, **kwargs): return resp - # Document default error response - doc['responses']['default'] = 'DEFAULT_ERROR' - # Store doc in wrapper function # The deepcopy avoids modifying the wrapped function doc wrapper._apidoc = deepcopy(getattr(wrapper, '_apidoc', {})) - wrapper._apidoc['response'] = doc + # Document default error response + wrapper._apidoc.setdefault( + 'response', {} + ).setdefault('responses', {})[code] = resp_doc + wrapper._apidoc[ + 'response']['responses']['default'] = 'DEFAULT_ERROR' # Indicate which code is the success status code # Helps other decorators documenting success response wrapper._apidoc['success_status_code'] = code @@ -117,9 +118,66 @@ def wrapper(*args, **kwargs): return decorator + def alt_response( + self, code, schema_or_ref, *, description=None, + example=None, examples=None, headers=None + ): + """Decorator documenting an alternative response + + :param int|str|HTTPStatus code: HTTP status code. + :param schema_or_ref: Either a :class:`Schema ` + class or instance or a string error reference. + When passing a reference, arguments below are ignored. + :param str description: Description of the response (default: None). + :param dict example: Example of response message. + :param list examples: Examples of response message. + :param dict headers: Headers returned by the response. + """ + # If a ref is passed + if isinstance(schema_or_ref, str): + resp_doc = schema_or_ref + # If a schema is passed + else: + schema = schema_or_ref + if isinstance(schema, type): + schema = schema() + + # Document response (schema, description,...) in the API doc + resp_doc = {} + doc_schema = self._make_doc_alternate_response_schema(schema) + if doc_schema is not None: + resp_doc['schema'] = doc_schema + if description is not None: + resp_doc['description'] = description + else: + resp_doc['description'] = http.HTTPStatus(int(code)).phrase + if example is not None: + resp_doc['example'] = example + if examples is not None: + resp_doc['examples'] = examples + if headers is not None: + resp_doc['headers'] = headers + + def decorator(func): + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + # Store doc in wrapper function + # The deepcopy avoids modifying the wrapped function doc + wrapper._apidoc = deepcopy(getattr(wrapper, '_apidoc', {})) + wrapper._apidoc.setdefault( + 'response', {} + ).setdefault('responses', {})[code] = resp_doc + + return wrapper + + return decorator + @staticmethod def _make_doc_response_schema(schema): - """Override this to modify schema in docs + """Override this to modify response schema in docs This can be used to document a wrapping structure. @@ -137,6 +195,11 @@ def _make_doc_response_schema(schema): """ return schema + @staticmethod + def _make_doc_alternate_response_schema(schema): + """Override this to modify alternate response schema in docs""" + return schema + @staticmethod def _prepare_response_content(data): """Override this to modify the data structure diff --git a/tests/conftest.py b/tests/conftest.py index 33edcea4..acf6e81b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -76,6 +76,11 @@ class Meta: arg1 = ma.fields.String() arg2 = ma.fields.Integer() + class ClientErrorSchema(ma.Schema): + error_id = ma.fields.Str() + text = ma.fields.Str() + return namedtuple( - 'Model', ('DocSchema', 'DocEtagSchema', 'QueryArgsSchema'))( - DocSchema, DocEtagSchema, QueryArgsSchema) + 'Model', + ('DocSchema', 'DocEtagSchema', 'QueryArgsSchema', 'ClientErrorSchema') + )(DocSchema, DocEtagSchema, QueryArgsSchema, ClientErrorSchema) diff --git a/tests/test_blueprint.py b/tests/test_blueprint.py index f1db3814..f71373bf 100644 --- a/tests/test_blueprint.py +++ b/tests/test_blueprint.py @@ -637,6 +637,172 @@ def func(): api.spec, 'response', 'DEFAULT_ERROR' ) + @pytest.mark.parametrize('openapi_version', ['2.0', '3.0.2']) + def test_blueprint_alt_response_schema( + self, app, openapi_version, schemas + ): + """Check alternate response schema is correctly documented""" + app.config['OPENAPI_VERSION'] = openapi_version + api = Api(app) + blp = Blueprint('test', 'test', url_prefix='/test') + + example = {'error_id': 'E1', 'text': 'client error'} + examples = { + 'example 1': {'error_id': 'E1', 'text': 'client error 1'}, + 'example 2': {'error_id': 'E2', 'text': 'client error 2'}, + } + headers = {'X-Custom-Header': 'Header value'} + + @blp.route('/') + @blp.alt_response(400, schemas.ClientErrorSchema) + def func(): + pass + + @blp.route('/description') + @blp.alt_response( + 400, schemas.ClientErrorSchema, description='Client error' + ) + def func_with_description(): + pass + + @blp.route('/example') + @blp.alt_response(400, schemas.ClientErrorSchema, example=example) + def func_with_example(): + pass + + if openapi_version == '3.0.2': + @blp.route('/examples') + @blp.alt_response( + 400, schemas.ClientErrorSchema, examples=examples + ) + def func_with_examples(): + pass + + @blp.route('/headers') + @blp.alt_response(400, schemas.ClientErrorSchema, headers=headers) + def func_with_headers(): + pass + + api.register_blueprint(blp) + + paths = api.spec.to_dict()['paths'] + + schema_ref = build_ref(api.spec, 'schema', 'ClientError') + + response = paths['/test/']['get']['responses']['400'] + if openapi_version == '2.0': + assert response['schema'] == schema_ref + else: + assert ( + response['content']['application/json']['schema'] == + schema_ref + ) + assert response['description'] == http.HTTPStatus(400).phrase + + response = paths['/test/description']['get']['responses']['400'] + assert response['description'] == 'Client error' + + response = paths['/test/example']['get']['responses']['400'] + if openapi_version == '2.0': + assert response['examples']['application/json'] == example + else: + assert ( + response['content']['application/json']['example'] == example + ) + + if openapi_version == '3.0.2': + response = paths['/test/examples']['get']['responses']['400'] + assert ( + response['content']['application/json']['examples'] == examples + ) + + response = paths['/test/headers']['get']['responses']['400'] + assert response['headers'] == headers + + @pytest.mark.parametrize('openapi_version', ['2.0', '3.0.2']) + def test_blueprint_alt_response_ref(self, app, openapi_version): + """Check alternate response passed as reference""" + app.config['OPENAPI_VERSION'] = openapi_version + api = Api(app) + api.spec.components.response('ClientErrorResponse') + + blp = Blueprint('test', 'test', url_prefix='/test') + + @blp.route('/') + @blp.alt_response(400, "ClientErrorResponse") + def func(): + pass + + api.register_blueprint(blp) + + paths = api.spec.to_dict()['paths'] + + response_ref = build_ref(api.spec, 'response', 'ClientErrorResponse') + + assert paths['/test/']['get']['responses']['400'] == response_ref + + @pytest.mark.parametrize('openapi_version', ['2.0', '3.0.2']) + def test_blueprint_multiple_alt_response( + self, app, openapi_version, schemas + ): + """Check multiple nested calls to alt_response""" + app.config['OPENAPI_VERSION'] = openapi_version + api = Api(app) + blp = Blueprint('test', 'test', url_prefix='/test') + + @blp.route('/') + @blp.alt_response(400, schemas.ClientErrorSchema) + @blp.alt_response(404, 'NotFoundErrorResponse') + def func(): + pass + + api.register_blueprint(blp) + + paths = api.spec.to_dict()['paths'] + + schema_ref = build_ref(api.spec, 'schema', 'ClientError') + response_ref = build_ref(api.spec, 'response', 'NotFoundErrorResponse') + + response = paths['/test/']['get']['responses']['400'] + if openapi_version == '2.0': + assert response['schema'] == schema_ref + else: + assert ( + response['content']['application/json']['schema'] == + schema_ref + ) + + assert paths['/test/']['get']['responses']['404'] == response_ref + + @pytest.mark.parametrize('openapi_version', ['2.0', '3.0.2']) + def test_blueprint_alt_response_wrapper( + self, app, schemas, openapi_version + ): + """Check alt_responses passes response transparently""" + app.config['OPENAPI_VERSION'] = openapi_version + api = Api(app) + api.spec.components.response('ClientErrorResponse') + + blp = Blueprint('test', 'test', url_prefix='/test') + client = app.test_client() + + @blp.route('/') + @blp.response(200, schemas.DocSchema) + @blp.alt_response(400, "ClientErrorResponse") + def func(): + return {'item_id': 12} + + api.register_blueprint(blp) + + paths = api.spec.to_dict()['paths'] + + response_ref = build_ref(api.spec, 'response', 'ClientErrorResponse') + + assert paths['/test/']['get']['responses']['400'] == response_ref + + resp = client.get('test/') + assert resp.json == {'item_id': 12} + @pytest.mark.parametrize('openapi_version', ('2.0', '3.0.2')) def test_blueprint_pagination(self, app, schemas, openapi_version): app.config['OPENAPI_VERSION'] = openapi_version