From af34e8883fe0d8dd8d7660ad0fb5ba59a0230112 Mon Sep 17 00:00:00 2001 From: Adrien Cacciaguerra Date: Mon, 7 Sep 2020 16:07:08 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20kwargs=20parsing=20to=20valid?= =?UTF-8?q?ate=20decorator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- flask_pydantic/core.py | 67 ++++++++++++++++++++++++++--------------- tests/conftest.py | 10 ++++++ tests/func/test_app.py | 6 ++++ tests/unit/test_core.py | 27 +++++++++++++++++ 4 files changed, 85 insertions(+), 25 deletions(-) diff --git a/flask_pydantic/core.py b/flask_pydantic/core.py index 07293b5..e108c18 100644 --- a/flask_pydantic/core.py +++ b/flask_pydantic/core.py @@ -85,6 +85,8 @@ def validate( - request.query_params - request.body_params + Or directly as `kwargs`, if you define them in the decorated function. + `exclude_none` whether to remove None fields from response `response_many` whether content of response consists of many objects (e. g. List[BaseModel]). Resulting response will be an array of serialized @@ -92,32 +94,38 @@ def validate( `request_body_many` whether response body contains array of given model (request.body_params then contains list of models i. e. List[BaseModel]) - example: + example:: + + from flask import request + from flask_pydantic import validate + from pydantic import BaseModel - from flask import request - from flask_pydantic import validate - from pydantic import BaseModel + class Query(BaseModel): + query: str - class Query(BaseModel): - query: str + class Body(BaseModel): + color: str - class Body(BaseModel): - color: str + class MyModel(BaseModel): + id: int + color: str + description: str - class MyModel(BaseModel): - id: int - color: str - description: str + ... - ... + @app.route("/") + @validate(query=Query, body=Body) + def test_route(): + query = request.query_params.query + color = request.body_params.query - @app.route("/") - @validate(query=Query, body=Body) - def test_route(): - query = request.query_params.query - color = request.body_params.query + return MyModel(...) - return MyModel(...) + @app.route("/kwargs") + @validate() + def test_route_kwargs(query:Query, body:Body): + + return MyModel(...) -> that will render JSON response with serialized MyModel instance """ @@ -126,22 +134,26 @@ def decorate(func: Callable[[InputParams], Any]) -> Callable[[InputParams], Any] @wraps(func) def wrapper(*args, **kwargs): q, b, err = None, None, {} - if query: - query_params = convert_query_params(request.args, query) + query_in_kwargs = func.__annotations__.get("query") + query_model = query_in_kwargs or query + if query_model: + query_params = convert_query_params(request.args, query_model) try: - q = query(**query_params) + q = query_model(**query_params) except ValidationError as ve: err["query_params"] = ve.errors() - if body: + body_in_kwargs = func.__annotations__.get("body") + body_model = body_in_kwargs or body + if body_model: body_params = request.get_json() if request_body_many: try: - b = validate_many_models(body, body_params) + b = validate_many_models(body_model, body_params) except ManyModelValidationError as e: err["body_params"] = e.errors() else: try: - b = body(**body_params) + b = body_model(**body_params) except TypeError: content_type = request.headers.get("Content-Type", "").lower() if content_type != "application/json": @@ -152,6 +164,11 @@ def wrapper(*args, **kwargs): err["body_params"] = ve.errors() request.query_params = q request.body_params = b + if query_in_kwargs: + kwargs["query"] = q + if body_in_kwargs: + kwargs["body"] = b + if err: status_code = current_app.config.get( "FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE", 400 diff --git a/tests/conftest.py b/tests/conftest.py index 1ce3997..0a7cbf3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -85,4 +85,14 @@ def post(): ] return response_model(results=results[: query_params.limit], count=len(results)) + @app.route("/search/kwargs", methods=["POST"]) + @validate() + def post_kwargs(query: query_model, body: body_model): + results = [ + post_model(**p) + for p in posts + if pass_search(p, body.search_term, body.exclude, query.min_views) + ] + return response_model(results=results[: query.limit], count=len(results)) + return app diff --git a/tests/func/test_app.py b/tests/func/test_app.py index 3ab52ce..78b5bc2 100644 --- a/tests/func/test_app.py +++ b/tests/func/test_app.py @@ -87,6 +87,12 @@ def test_post(self, client, query, body, expected_status, expected_response): assert response.json == expected_response assert response.status_code == expected_status + @pytest.mark.parametrize("query,body,expected_status,expected_response", test_cases) + def test_post_kwargs(self, client, query, body, expected_status, expected_response): + response = client.post(f"/search/kwargs{query}", json=body) + assert response.json == expected_response + assert response.status_code == expected_status + def test_error_status_code(self, app, mocker, client): mocker.patch.dict( app.config, {"FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE": 422} diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index ed77e2e..7252847 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -179,6 +179,33 @@ def f(): exclude_none=True, exclude_defaults=True ) == parameters.request_query.to_dict(flat=True) + @pytest.mark.parametrize("parameters", validate_test_cases) + def test_validate_kwargs(self, mocker, request_ctx, parameters: ValidateParams): + mock_request = mocker.patch.object(request_ctx, "request") + mock_request.args = parameters.request_query + mock_request.get_json = lambda: parameters.request_body + + def f(body: parameters.body_model, query: parameters.query_model): + return parameters.response_model(**body.dict(), **query.dict()) + + response = validate( + on_success_status=parameters.on_success_status, + exclude_none=parameters.exclude_none, + response_many=parameters.response_many, + request_body_many=parameters.request_body_many, + )(f)() + + assert response.status_code == parameters.expected_status_code + assert response.json == parameters.expected_response_body + if 200 <= response.status_code < 300: + assert ( + mock_request.body_params.dict(exclude_none=True, exclude_defaults=True) + == parameters.request_body + ) + assert mock_request.query_params.dict( + exclude_none=True, exclude_defaults=True + ) == parameters.request_query.to_dict(flat=True) + @pytest.mark.usefixtures("request_ctx") def test_response_with_status(self): expected_status_code = 201