Skip to content

Commit

Permalink
✨ Add kwargs parsing to validate decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
adriencaccia committed Sep 7, 2020
1 parent 783c6db commit af34e88
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 25 deletions.
67 changes: 42 additions & 25 deletions flask_pydantic/core.py
Expand Up @@ -85,39 +85,47 @@ def validate(
- request.query_params
- request.body_params
Or directly as `kwargs`, if you define them in the decorated function.
`exclude_none` whether to remove None fields from response
`response_many` whether content of response consists of many objects
(e. g. List[BaseModel]). Resulting response will be an array of serialized
models.
`request_body_many` whether response body contains array of given model
(request.body_params then contains list of models i. e. List[BaseModel])
example:
example::
from flask import request
from flask_pydantic import validate
from pydantic import BaseModel
from flask import request
from flask_pydantic import validate
from pydantic import BaseModel
class Query(BaseModel):
query: str
class Query(BaseModel):
query: str
class Body(BaseModel):
color: str
class Body(BaseModel):
color: str
class MyModel(BaseModel):
id: int
color: str
description: str
class MyModel(BaseModel):
id: int
color: str
description: str
...
...
@app.route("/")
@validate(query=Query, body=Body)
def test_route():
query = request.query_params.query
color = request.body_params.query
@app.route("/")
@validate(query=Query, body=Body)
def test_route():
query = request.query_params.query
color = request.body_params.query
return MyModel(...)
return MyModel(...)
@app.route("/kwargs")
@validate()
def test_route_kwargs(query:Query, body:Body):
return MyModel(...)
-> that will render JSON response with serialized MyModel instance
"""
Expand All @@ -126,22 +134,26 @@ def decorate(func: Callable[[InputParams], Any]) -> Callable[[InputParams], Any]
@wraps(func)
def wrapper(*args, **kwargs):
q, b, err = None, None, {}
if query:
query_params = convert_query_params(request.args, query)
query_in_kwargs = func.__annotations__.get("query")
query_model = query_in_kwargs or query
if query_model:
query_params = convert_query_params(request.args, query_model)
try:
q = query(**query_params)
q = query_model(**query_params)
except ValidationError as ve:
err["query_params"] = ve.errors()
if body:
body_in_kwargs = func.__annotations__.get("body")
body_model = body_in_kwargs or body
if body_model:
body_params = request.get_json()
if request_body_many:
try:
b = validate_many_models(body, body_params)
b = validate_many_models(body_model, body_params)
except ManyModelValidationError as e:
err["body_params"] = e.errors()
else:
try:
b = body(**body_params)
b = body_model(**body_params)
except TypeError:
content_type = request.headers.get("Content-Type", "").lower()
if content_type != "application/json":
Expand All @@ -152,6 +164,11 @@ def wrapper(*args, **kwargs):
err["body_params"] = ve.errors()
request.query_params = q
request.body_params = b
if query_in_kwargs:
kwargs["query"] = q
if body_in_kwargs:
kwargs["body"] = b

if err:
status_code = current_app.config.get(
"FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE", 400
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Expand Up @@ -85,4 +85,14 @@ def post():
]
return response_model(results=results[: query_params.limit], count=len(results))

@app.route("/search/kwargs", methods=["POST"])
@validate()
def post_kwargs(query: query_model, body: body_model):
results = [
post_model(**p)
for p in posts
if pass_search(p, body.search_term, body.exclude, query.min_views)
]
return response_model(results=results[: query.limit], count=len(results))

return app
6 changes: 6 additions & 0 deletions tests/func/test_app.py
Expand Up @@ -87,6 +87,12 @@ def test_post(self, client, query, body, expected_status, expected_response):
assert response.json == expected_response
assert response.status_code == expected_status

@pytest.mark.parametrize("query,body,expected_status,expected_response", test_cases)
def test_post_kwargs(self, client, query, body, expected_status, expected_response):
response = client.post(f"/search/kwargs{query}", json=body)
assert response.json == expected_response
assert response.status_code == expected_status

def test_error_status_code(self, app, mocker, client):
mocker.patch.dict(
app.config, {"FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE": 422}
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_core.py
Expand Up @@ -179,6 +179,33 @@ def f():
exclude_none=True, exclude_defaults=True
) == parameters.request_query.to_dict(flat=True)

@pytest.mark.parametrize("parameters", validate_test_cases)
def test_validate_kwargs(self, mocker, request_ctx, parameters: ValidateParams):
mock_request = mocker.patch.object(request_ctx, "request")
mock_request.args = parameters.request_query
mock_request.get_json = lambda: parameters.request_body

def f(body: parameters.body_model, query: parameters.query_model):
return parameters.response_model(**body.dict(), **query.dict())

response = validate(
on_success_status=parameters.on_success_status,
exclude_none=parameters.exclude_none,
response_many=parameters.response_many,
request_body_many=parameters.request_body_many,
)(f)()

assert response.status_code == parameters.expected_status_code
assert response.json == parameters.expected_response_body
if 200 <= response.status_code < 300:
assert (
mock_request.body_params.dict(exclude_none=True, exclude_defaults=True)
== parameters.request_body
)
assert mock_request.query_params.dict(
exclude_none=True, exclude_defaults=True
) == parameters.request_query.to_dict(flat=True)

@pytest.mark.usefixtures("request_ctx")
def test_response_with_status(self):
expected_status_code = 201
Expand Down

0 comments on commit af34e88

Please sign in to comment.