Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow view functions to return tuple or Response class #40

Merged
merged 5 commits into from
Mar 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions flask_rest_api/etag.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,12 @@ def wrapper(*args, **kwargs):
# Pass data to use as ETag data if set_etag was not called
# If etag_schema is provided, pass raw result rather than
# dump, as the dump needs to be done using etag_schema
etag_data = get_appcontext()[
'result_dump' if etag_schema is None else 'result_raw'
]
# If 'result_dump'/'result_raw' is not in appcontext,
# the Etag must have been set manually. Just pass None.
etag_data = get_appcontext().get(
'result_dump' if etag_schema is None else 'result_raw',
None
)
self._set_etag_in_response(resp, etag_data, etag_schema)

return resp
Expand Down
11 changes: 7 additions & 4 deletions flask_rest_api/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import marshmallow as ma
from webargs.flaskparser import FlaskParser

from .utils import get_appcontext
from .utils import unpack_tuple_response
from .compat import MARSHMALLOW_VERSION_MAJOR


Expand Down Expand Up @@ -166,7 +166,8 @@ def wrapper(*args, **kwargs):
kwargs['pagination_parameters'] = page_params

# Execute decorated function
result = func(*args, **kwargs)
result, status, headers = unpack_tuple_response(
func(*args, **kwargs))

# Post pagination: use pager class to paginate the result
if pager is not None:
Expand All @@ -182,10 +183,12 @@ def wrapper(*args, **kwargs):
page_header = self._make_pagination_header(
page_params.page, page_params.page_size,
page_params.item_count)
get_appcontext()['headers'][
if headers is None:
headers = {}
headers[
self.PAGINATION_HEADER_FIELD_NAME] = page_header

return result
return result, status, headers

return wrapper

Expand Down
28 changes: 21 additions & 7 deletions flask_rest_api/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

from functools import wraps

from werkzeug import BaseResponse
from flask import jsonify

from .utils import deepupdate, get_appcontext
from .utils import (
deepupdate, get_appcontext,
unpack_tuple_response, set_status_and_headers_in_response
)
from .compat import MARSHMALLOW_VERSION_MAJOR


Expand All @@ -16,7 +20,8 @@ def response(self, schema=None, *, code=200, description=''):

:param schema: :class:`Schema <marshmallow.Schema>` class or instance.
If not None, will be used to serialize response data.
:param int code: HTTP status code (default: 200).
:param int code: HTTP status code (default: 200). Used if none is
returned from the view function.
:param str descripton: Description of the response.

See :doc:`Response <response>`.
Expand All @@ -37,7 +42,14 @@ def decorator(func):
def wrapper(*args, **kwargs):

# Execute decorated function
result_raw = func(*args, **kwargs)
result_raw, status, headers = unpack_tuple_response(
func(*args, **kwargs))

# If return value is a werkzeug BaseResponse, return it
if isinstance(result_raw, BaseResponse):
set_status_and_headers_in_response(
result_raw, status, headers)
return result_raw

# Dump result with schema if specified
if schema is None:
Expand All @@ -48,13 +60,15 @@ def wrapper(*args, **kwargs):
result_dump = result_dump[0]

# Store result in appcontext (may be used for ETag computation)
get_appcontext()['result_raw'] = result_raw
get_appcontext()['result_dump'] = result_dump
appcontext = get_appcontext()
appcontext['result_raw'] = result_raw
appcontext['result_dump'] = result_dump

# Build response
resp = jsonify(self._prepare_response_content(result_dump))
resp.headers.extend(get_appcontext()['headers'])
resp.status_code = code
set_status_and_headers_in_response(resp, status, headers)
if status is None:
resp.status_code = code

return resp

Expand Down
42 changes: 42 additions & 0 deletions flask_rest_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import defaultdict
from collections.abc import Mapping

from werkzeug.datastructures import Headers
from flask import _app_ctx_stack
from apispec.utils import trim_docstring, dedent

Expand Down Expand Up @@ -62,3 +63,44 @@ def load_info_from_docstring(docstring):
if description_lines:
info['description'] = dedent('\n'.join(description_lines))
return info


# Copied from flask
def unpack_tuple_response(rv):
"""Unpack a flask Response tuple"""

status = headers = None

# unpack tuple returns
if isinstance(rv, tuple):
len_rv = len(rv)

# a 3-tuple is unpacked directly
if len_rv == 3:
rv, status, headers = rv
# decide if a 2-tuple has status or headers
elif len_rv == 2:
if isinstance(rv[1], (Headers, dict, tuple, list)):
rv, headers = rv
else:
rv, status = rv
# other sized tuples are not allowed
else:
raise TypeError(
'The view function did not return a valid response tuple.'
' The tuple must have the form (body, status, headers),'
' (body, status), or (body, headers).'
)

return rv, status, headers


def set_status_and_headers_in_response(response, status, headers):
"""Set status and headers in flask Reponse object"""
if headers:
response.headers.extend(headers)
if status is not None:
if isinstance(status, int):
response.status_code = status
else:
response.status = status
146 changes: 144 additions & 2 deletions tests/test_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from flask import jsonify
from flask.views import MethodView

from flask_rest_api import Api
from flask_rest_api.blueprint import Blueprint
from flask_rest_api import Api, Blueprint, Page
from flask_rest_api.exceptions import InvalidLocationError


Expand Down Expand Up @@ -481,3 +480,146 @@ def func():

assert 'get' in paths['/test/route_1']
assert 'get' in paths['/test/route_2']

def test_blueprint_response_tuple(self, app):
api = Api(app)
blp = Blueprint('test', __name__, url_prefix='/test')
client = app.test_client()

@blp.route('/response')
@blp.response()
def func_response():
return {}

@blp.route('/response_code_int')
@blp.response()
def func_response_code_int():
return {}, 201

@blp.route('/response_code_str')
@blp.response()
def func_response_code_str():
return {}, '201 CREATED'

@blp.route('/response_headers')
@blp.response()
def func_response_headers():
return {}, {'X-header': 'test'}

@blp.route('/response_code_int_headers')
@blp.response()
def func_response_code_int_headers():
return {}, 201, {'X-header': 'test'}

@blp.route('/response_code_str_headers')
@blp.response()
def func_response_code_str_headers():
return {}, '201 CREATED', {'X-header': 'test'}

@blp.route('/response_wrong_tuple')
@blp.response()
def func_response_wrong_tuple():
return {}, 201, {'X-header': 'test'}, 'extra'

api.register_blueprint(blp)

response = client.get('/test/response')
assert response.status_code == 200
assert response.json == {}
response = client.get('/test/response_code_int')
assert response.status_code == 201
assert response.status == '201 CREATED'
assert response.json == {}
response = client.get('/test/response_code_str')
assert response.status_code == 201
assert response.status == '201 CREATED'
assert response.json == {}
response = client.get('/test/response_headers')
assert response.status_code == 200
assert response.json == {}
assert response.headers['X-header'] == 'test'
response = client.get('/test/response_code_int_headers')
assert response.status_code == 201
assert response.status == '201 CREATED'
assert response.json == {}
assert response.headers['X-header'] == 'test'
response = client.get('/test/response_code_str_headers')
assert response.status_code == 201
assert response.status == '201 CREATED'
assert response.json == {}
assert response.headers['X-header'] == 'test'
response = client.get('/test/response_wrong_tuple')
assert response.status_code == 500

def test_blueprint_pagination_response_tuple(self, app):
api = Api(app)
blp = Blueprint('test', __name__, url_prefix='/test')
client = app.test_client()

@blp.route('/response')
@blp.response()
@blp.paginate(Page)
def func_response():
return [1, 2]

@blp.route('/response_code')
@blp.response()
@blp.paginate(Page)
def func_response_code():
return [1, 2], 201

@blp.route('/response_headers')
@blp.response()
@blp.paginate(Page)
def func_response_headers():
return [1, 2], {'X-header': 'test'}

@blp.route('/response_code_headers')
@blp.response()
@blp.paginate(Page)
def func_response_code_headers():
return [1, 2], 201, {'X-header': 'test'}

@blp.route('/response_wrong_tuple')
@blp.response()
@blp.paginate(Page)
def func_response_wrong_tuple():
return [1, 2], 201, {'X-header': 'test'}, 'extra'

api.register_blueprint(blp)

response = client.get('/test/response')
assert response.status_code == 200
assert response.json == [1, 2]
response = client.get('/test/response_code')
assert response.status_code == 201
assert response.json == [1, 2]
response = client.get('/test/response_headers')
assert response.status_code == 200
assert response.json == [1, 2]
assert response.headers['X-header'] == 'test'
response = client.get('/test/response_code_headers')
assert response.status_code == 201
assert response.json == [1, 2]
assert response.headers['X-header'] == 'test'
response = client.get('/test/response_wrong_tuple')
assert response.status_code == 500

def test_blueprint_response_response_object(self, app, schemas):
api = Api(app)
blp = Blueprint('test', __name__, url_prefix='/test')
client = app.test_client()

@blp.route('/response')
# Schema is ignored when response object is returned
@blp.response(schemas.DocSchema, code=200)
def func_response():
return jsonify({}), 201, {'X-header': 'test'}

api.register_blueprint(blp)

response = client.get('/test/response')
assert response.status_code == 201
assert response.status == '201 CREATED'
assert response.json == {}
assert response.headers['X-header'] == 'test'
22 changes: 21 additions & 1 deletion tests/test_etag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from flask import Response
from flask import jsonify, Response
from flask.views import MethodView

from flask_rest_api import Api, Blueprint, abort
Expand Down Expand Up @@ -358,6 +358,26 @@ def test_etag_set_etag_in_response(self, app, schemas, paginate):
blp._set_etag_in_response(resp, item, etag_schema)
assert resp.get_etag() == (etag_with_schema, False)

def test_etag_response_object(self, app):
api = Api(app)
blp = Blueprint('test', __name__, url_prefix='/test')
client = app.test_client()

@blp.route('/')
@blp.etag
@blp.response()
def func_response_etag():
# When the view function returns a Response object,
# the ETag must be specified manually
blp.set_etag('test')
return jsonify({})

api.register_blueprint(blp)

response = client.get('/test/')
assert response.json == {}
assert response.get_etag() == (blp._generate_etag('test'), False)

def test_etag_operations_etag_enabled(self, app_with_etag):

client = app_with_etag.test_client()
Expand Down