Skip to content

Commit

Permalink
match flask pattern to return extra headers
Browse files Browse the repository at this point in the history
Technically Flask allows you to return from a route with a tuple of 3
items with the last item being a dictionary of headers. The docs mention
it here: https://flask.palletsprojects.com/en/1.1.x/quickstart/#about-responses
This adjusts the behavior of the validate decorator to allow this as
well. Matches the behavior in: https://github.com/pallets/flask/blob/64213fc0214c1044fa2c9e60d0e2683e75d125c0/src/flask/app.py#L1644-L1646
  • Loading branch information
cardoe committed May 13, 2022
1 parent c801aa1 commit 27d5c82
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
19 changes: 16 additions & 3 deletions flask_pydantic/core.py
Expand Up @@ -234,15 +234,28 @@ def wrapper(*args, **kwargs):

if (
isinstance(res, tuple)
and len(res) == 2
and len(res) in [2, 3]
and isinstance(res[0], BaseModel)
):
return make_json_response(
headers = None
status = on_success_status
if isinstance(res[1], (dict, tuple, list)):
headers = res[1]
elif len(res) == 3 and isinstance(res[2], (dict, tuple, list)):
status = res[1]
headers = res[2]
else:
status = res[1]

ret = make_json_response(
res[0],
res[1],
status,
exclude_none=exclude_none,
by_alias=response_by_alias,
)
if headers:
ret.headers.update(headers)
return ret

return res

Expand Down
16 changes: 16 additions & 0 deletions tests/func/test_app.py
Expand Up @@ -75,6 +75,14 @@ def root_type(body: PersonBulk):
return {"number": len(body)}


@pytest.fixture
def app_with_custom_headers(app):
@app.route("/custom_headers", methods=["GET"])
@validate()
def custom_headers():
return {"test": 1}, 200, {"CUSTOM_HEADER": "UNIQUE"}


@pytest.fixture
def app_with_camel_route(app):
def to_camel(x: str) -> str:
Expand Down Expand Up @@ -190,6 +198,14 @@ def test_custom_root_types(client):
assert response.json == {"number": 2}


@pytest.mark.usefixtures("app_with_custom_headers")
def test_custom_headers(client):
response = client.get("/custom_headers")
assert response.json == {"test": 1}
assert response.status_code == 200
assert response.headers.get("CUSTOM_HEADER") == "UNIQUE"


@pytest.mark.usefixtures("app_with_array_route")
class TestArrayQueryParam:
def test_no_param_raises(self, client):
Expand Down

0 comments on commit 27d5c82

Please sign in to comment.