From 27d5c82fea0bc0bceba1b2e3b398d68962eae56f Mon Sep 17 00:00:00 2001 From: Doug Goldstein Date: Thu, 15 Apr 2021 15:50:31 -0500 Subject: [PATCH] match flask pattern to return extra headers Technically Flask allows you to return from a route with a tuple of 3 items with the last item being a dictionary of headers. The docs mention it here: https://flask.palletsprojects.com/en/1.1.x/quickstart/#about-responses This adjusts the behavior of the validate decorator to allow this as well. Matches the behavior in: https://github.com/pallets/flask/blob/64213fc0214c1044fa2c9e60d0e2683e75d125c0/src/flask/app.py#L1644-L1646 --- flask_pydantic/core.py | 19 ++++++++++++++++--- tests/func/test_app.py | 16 ++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/flask_pydantic/core.py b/flask_pydantic/core.py index f34dd0d..4c261a7 100644 --- a/flask_pydantic/core.py +++ b/flask_pydantic/core.py @@ -234,15 +234,28 @@ def wrapper(*args, **kwargs): if ( isinstance(res, tuple) - and len(res) == 2 + and len(res) in [2, 3] and isinstance(res[0], BaseModel) ): - return make_json_response( + headers = None + status = on_success_status + if isinstance(res[1], (dict, tuple, list)): + headers = res[1] + elif len(res) == 3 and isinstance(res[2], (dict, tuple, list)): + status = res[1] + headers = res[2] + else: + status = res[1] + + ret = make_json_response( res[0], - res[1], + status, exclude_none=exclude_none, by_alias=response_by_alias, ) + if headers: + ret.headers.update(headers) + return ret return res diff --git a/tests/func/test_app.py b/tests/func/test_app.py index 8e21687..b667666 100644 --- a/tests/func/test_app.py +++ b/tests/func/test_app.py @@ -75,6 +75,14 @@ def root_type(body: PersonBulk): return {"number": len(body)} +@pytest.fixture +def app_with_custom_headers(app): + @app.route("/custom_headers", methods=["GET"]) + @validate() + def custom_headers(): + return {"test": 1}, 200, {"CUSTOM_HEADER": "UNIQUE"} + + @pytest.fixture def app_with_camel_route(app): def to_camel(x: str) -> str: @@ -190,6 +198,14 @@ def test_custom_root_types(client): assert response.json == {"number": 2} +@pytest.mark.usefixtures("app_with_custom_headers") +def test_custom_headers(client): + response = client.get("/custom_headers") + assert response.json == {"test": 1} + assert response.status_code == 200 + assert response.headers.get("CUSTOM_HEADER") == "UNIQUE" + + @pytest.mark.usefixtures("app_with_array_route") class TestArrayQueryParam: def test_no_param_raises(self, client):