diff --git a/flask_rest_api/etag.py b/flask_rest_api/etag.py index 49ff47c4..52f9faa1 100644 --- a/flask_rest_api/etag.py +++ b/flask_rest_api/etag.py @@ -85,9 +85,12 @@ def wrapper(*args, **kwargs): # Pass data to use as ETag data if set_etag was not called # If etag_schema is provided, pass raw result rather than # dump, as the dump needs to be done using etag_schema - etag_data = get_appcontext()[ - 'result_dump' if etag_schema is None else 'result_raw' - ] + # If 'result_dump'/'result_raw' is not in appcontext, + # the Etag must have been set manually. Just pass None. + etag_data = get_appcontext().get( + 'result_dump' if etag_schema is None else 'result_raw', + None + ) self._set_etag_in_response(resp, etag_data, etag_schema) return resp diff --git a/flask_rest_api/pagination.py b/flask_rest_api/pagination.py index b8fd6df5..c43dc731 100644 --- a/flask_rest_api/pagination.py +++ b/flask_rest_api/pagination.py @@ -13,12 +13,13 @@ from functools import wraps import json +from werkzeug.datastructures import Headers from flask import request, current_app import marshmallow as ma from webargs.flaskparser import FlaskParser -from .utils import get_appcontext +from .utils import unpack_tuple_response from .compat import MARSHMALLOW_VERSION_MAJOR @@ -166,7 +167,8 @@ def wrapper(*args, **kwargs): kwargs['pagination_parameters'] = page_params # Execute decorated function - result = func(*args, **kwargs) + result, status, headers = unpack_tuple_response( + func(*args, **kwargs)) # Post pagination: use pager class to paginate the result if pager is not None: @@ -182,10 +184,12 @@ def wrapper(*args, **kwargs): page_header = self._make_pagination_header( page_params.page, page_params.page_size, page_params.item_count) - get_appcontext()['headers'][ + if headers is None: + headers = Headers() + headers[ self.PAGINATION_HEADER_FIELD_NAME] = page_header - return result + return result, status, headers return wrapper diff --git a/flask_rest_api/response.py b/flask_rest_api/response.py index bce09e83..1bd5d495 100644 --- a/flask_rest_api/response.py +++ b/flask_rest_api/response.py @@ -2,9 +2,13 @@ from functools import wraps +from werkzeug import BaseResponse from flask import jsonify -from .utils import deepupdate, get_appcontext +from .utils import ( + deepupdate, get_appcontext, + unpack_tuple_response, set_status_and_headers_in_response +) from .compat import MARSHMALLOW_VERSION_MAJOR @@ -16,7 +20,8 @@ def response(self, schema=None, *, code=200, description=''): :param schema: :class:`Schema ` class or instance. If not None, will be used to serialize response data. - :param int code: HTTP status code (default: 200). + :param int code: HTTP status code (default: 200). Used if none is + returned from the view function. :param str descripton: Description of the response. See :doc:`Response `. @@ -36,8 +41,17 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): + appcontext = get_appcontext() + # Execute decorated function - result_raw = func(*args, **kwargs) + result_raw, status, headers = unpack_tuple_response( + func(*args, **kwargs)) + + # If return value is a werkzeug BaseResponse, return it + if isinstance(result_raw, BaseResponse): + set_status_and_headers_in_response( + result_raw, status, headers) + return result_raw # Dump result with schema if specified if schema is None: @@ -48,13 +62,14 @@ def wrapper(*args, **kwargs): result_dump = result_dump[0] # Store result in appcontext (may be used for ETag computation) - get_appcontext()['result_raw'] = result_raw - get_appcontext()['result_dump'] = result_dump + appcontext['result_raw'] = result_raw + appcontext['result_dump'] = result_dump # Build response resp = jsonify(self._prepare_response_content(result_dump)) - resp.headers.extend(get_appcontext()['headers']) - resp.status_code = code + set_status_and_headers_in_response(resp, status, headers) + if status is None: + resp.status_code = code return resp diff --git a/flask_rest_api/utils.py b/flask_rest_api/utils.py index eb3e138d..3c44caf1 100644 --- a/flask_rest_api/utils.py +++ b/flask_rest_api/utils.py @@ -3,6 +3,7 @@ from collections import defaultdict from collections.abc import Mapping +from werkzeug.datastructures import Headers from flask import _app_ctx_stack from apispec.utils import trim_docstring, dedent @@ -62,3 +63,44 @@ def load_info_from_docstring(docstring): if description_lines: info['description'] = dedent('\n'.join(description_lines)) return info + + +# Copied from flask +def unpack_tuple_response(rv): + """Unpack a flask Response tuple""" + + status = headers = None + + # unpack tuple returns + if isinstance(rv, tuple): + len_rv = len(rv) + + # a 3-tuple is unpacked directly + if len_rv == 3: + rv, status, headers = rv + # decide if a 2-tuple has status or headers + elif len_rv == 2: + if isinstance(rv[1], (Headers, dict, tuple, list)): + rv, headers = rv + else: + rv, status = rv + # other sized tuples are not allowed + else: + raise TypeError( + 'The view function did not return a valid response tuple.' + ' The tuple must have the form (body, status, headers),' + ' (body, status), or (body, headers).' + ) + + return rv, status, headers + + +def set_status_and_headers_in_response(response, status, headers): + """Set status and headers in flask Reponse object""" + if headers: + response.headers.extend(headers) + if status is not None: + if isinstance(status, int): + response.status_code = status + else: + response.status = status diff --git a/tests/test_blueprint.py b/tests/test_blueprint.py index 4b34d2b9..95e844e5 100644 --- a/tests/test_blueprint.py +++ b/tests/test_blueprint.py @@ -8,8 +8,7 @@ from flask import jsonify from flask.views import MethodView -from flask_rest_api import Api -from flask_rest_api.blueprint import Blueprint +from flask_rest_api import Api, Blueprint, Page from flask_rest_api.exceptions import InvalidLocationError @@ -481,3 +480,146 @@ def func(): assert 'get' in paths['/test/route_1'] assert 'get' in paths['/test/route_2'] + + def test_blueprint_response_tuple(self, app): + api = Api(app) + blp = Blueprint('test', __name__, url_prefix='/test') + client = app.test_client() + + @blp.route('/response') + @blp.response() + def func_response(): + return {} + + @blp.route('/response_code_int') + @blp.response() + def func_response_code_int(): + return {}, 201 + + @blp.route('/response_code_str') + @blp.response() + def func_response_code_str(): + return {}, '201 CREATED' + + @blp.route('/response_headers') + @blp.response() + def func_response_headers(): + return {}, {'X-header': 'test'} + + @blp.route('/response_code_int_headers') + @blp.response() + def func_response_code_int_headers(): + return {}, 201, {'X-header': 'test'} + + @blp.route('/response_code_str_headers') + @blp.response() + def func_response_code_str_headers(): + return {}, '201 CREATED', {'X-header': 'test'} + + @blp.route('/response_wrong_tuple') + @blp.response() + def func_response_wrong_tuple(): + return {}, 201, {'X-header': 'test'}, 'extra' + + api.register_blueprint(blp) + + response = client.get('/test/response') + assert response.status_code == 200 + assert response.json == {} + response = client.get('/test/response_code_int') + assert response.status_code == 201 + assert response.status == '201 CREATED' + assert response.json == {} + response = client.get('/test/response_code_str') + assert response.status_code == 201 + assert response.status == '201 CREATED' + assert response.json == {} + response = client.get('/test/response_headers') + assert response.status_code == 200 + assert response.json == {} + assert response.headers['X-header'] == 'test' + response = client.get('/test/response_code_int_headers') + assert response.status_code == 201 + assert response.status == '201 CREATED' + assert response.json == {} + assert response.headers['X-header'] == 'test' + response = client.get('/test/response_code_str_headers') + assert response.status_code == 201 + assert response.status == '201 CREATED' + assert response.json == {} + assert response.headers['X-header'] == 'test' + response = client.get('/test/response_wrong_tuple') + assert response.status_code == 500 + + def test_blueprint_pagination_response_tuple(self, app): + api = Api(app) + blp = Blueprint('test', __name__, url_prefix='/test') + client = app.test_client() + + @blp.route('/response') + @blp.response() + @blp.paginate(Page) + def func_response(): + return [1, 2] + + @blp.route('/response_code') + @blp.response() + @blp.paginate(Page) + def func_response_code(): + return [1, 2], 201 + + @blp.route('/response_headers') + @blp.response() + @blp.paginate(Page) + def func_response_headers(): + return [1, 2], {'X-header': 'test'} + + @blp.route('/response_code_headers') + @blp.response() + @blp.paginate(Page) + def func_response_code_headers(): + return [1, 2], 201, {'X-header': 'test'} + + @blp.route('/response_wrong_tuple') + @blp.response() + @blp.paginate(Page) + def func_response_wrong_tuple(): + return [1, 2], 201, {'X-header': 'test'}, 'extra' + + api.register_blueprint(blp) + + response = client.get('/test/response') + assert response.status_code == 200 + assert response.json == [1, 2] + response = client.get('/test/response_code') + assert response.status_code == 201 + assert response.json == [1, 2] + response = client.get('/test/response_headers') + assert response.status_code == 200 + assert response.json == [1, 2] + assert response.headers['X-header'] == 'test' + response = client.get('/test/response_code_headers') + assert response.status_code == 201 + assert response.json == [1, 2] + assert response.headers['X-header'] == 'test' + response = client.get('/test/response_wrong_tuple') + assert response.status_code == 500 + + def test_blueprint_response_response_object(self, app, schemas): + api = Api(app) + blp = Blueprint('test', __name__, url_prefix='/test') + client = app.test_client() + + @blp.route('/response') + # Schema is ignored when response object is returned + @blp.response(schemas.DocSchema, code=200) + def func_response(): + return jsonify({}), 201, {'X-header': 'test'} + + api.register_blueprint(blp) + + response = client.get('/test/response') + assert response.status_code == 201 + assert response.status == '201 CREATED' + assert response.json == {} + assert response.headers['X-header'] == 'test' diff --git a/tests/test_etag.py b/tests/test_etag.py index 0fe2acd4..62bff595 100644 --- a/tests/test_etag.py +++ b/tests/test_etag.py @@ -7,7 +7,7 @@ import pytest -from flask import Response +from flask import jsonify, Response from flask.views import MethodView from flask_rest_api import Api, Blueprint, abort @@ -358,6 +358,26 @@ def test_etag_set_etag_in_response(self, app, schemas, paginate): blp._set_etag_in_response(resp, item, etag_schema) assert resp.get_etag() == (etag_with_schema, False) + def test_etag_response_object(self, app): + api = Api(app) + blp = Blueprint('test', __name__, url_prefix='/test') + client = app.test_client() + + @blp.route('/') + @blp.etag + @blp.response() + def func_response_etag(): + # When the view function returns a Response object, + # the ETag must be specified manually + blp.set_etag('test') + return jsonify({}) + + api.register_blueprint(blp) + + response = client.get('/test/') + assert response.json == {} + assert response.get_etag() == (blp._generate_etag('test'), False) + def test_etag_operations_etag_enabled(self, app_with_etag): client = app_with_etag.test_client()