Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support dict as Schema #423

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions flask_smorest/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import wraps
import http

import marshmallow as ma
from webargs.flaskparser import FlaskParser

from .utils import deepupdate
Expand All @@ -28,8 +29,8 @@ def arguments(
):
"""Decorator specifying the schema used to deserialize parameters

:param type|Schema schema: Marshmallow ``Schema`` class or instance
used to deserialize and validate the argument.
:param type|Schema|dict schema: Marshmallow ``Schema`` class or instance
or dict used to deserialize and validate the argument.
:param str location: Location of the argument.
:param str content_type: Content type of the argument.
Should only be used in conjunction with ``json``, ``form`` or
Expand All @@ -56,6 +57,8 @@ def arguments(

See :doc:`Arguments <arguments>`.
"""
if isinstance(schema, dict):
schema = ma.Schema.from_dict(schema)
# At this stage, put schema instance in doc dictionary. Il will be
# replaced later on by $ref or json.
parameters = {
Expand Down
10 changes: 9 additions & 1 deletion flask_smorest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import abc

import marshmallow as ma
from werkzeug.datastructures import Headers
from flask import g
from apispec.utils import trim_docstring, dedent
Expand Down Expand Up @@ -31,9 +32,16 @@ def remove_none(mapping):
def resolve_schema_instance(schema):
"""Return schema instance for given schema (instance or class).

:param type|Schema schema: marshmallow.Schema instance or class
:param type|Schema|dict schema: marshmallow.Schema instance or class or dict
:return: schema instance of given schema
"""

# this dict may be used to document a file response, no a schema dict
if isinstance(schema, dict) and all(
[isinstance(v, (type, ma.fields.Field)) for v in schema.values()]
):
schema = ma.Schema.from_dict(schema)

return schema() if isinstance(schema, type) else schema


Expand Down
11 changes: 8 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class ClientErrorSchema(ma.Schema):
error_id = ma.fields.Str()
text = ma.fields.Str()

return namedtuple("Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema"))(
DocSchema, QueryArgsSchema, ClientErrorSchema
)
DictSchema = {
"item_id": ma.fields.Int(dump_only=True),
"field": ma.fields.Int(attribute="db_field"),
}

return namedtuple(
"Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema", "DictSchema")
)(DocSchema, QueryArgsSchema, ClientErrorSchema, DictSchema)
59 changes: 59 additions & 0 deletions tests/test_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,65 @@ def func(document, query_args):
"query_args": {"arg1": "test"},
}

@pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2"))
def test_blueprint_dict_argument_schema(self, app, schemas, openapi_version):
app.config["OPENAPI_VERSION"] = openapi_version
api = Api(app)
blp = Blueprint("test", __name__, url_prefix="/test")
client = app.test_client()

@blp.route("/", methods=("POST",))
@blp.arguments(schemas.DictSchema)
def func(document):
return {"document": document}

api.register_blueprint(blp)
spec = api.spec.to_dict()

# Check parameters are documented
if openapi_version == "2.0":
parameters = spec["paths"]["/test/"]["post"]["parameters"]
assert len(parameters) == 1
assert parameters[0]["in"] == "body"
assert "schema" in parameters[0]
else:
assert (
"schema"
in spec["paths"]["/test/"]["post"]["requestBody"]["content"][
"application/json"
]
)

# Check parameters are passed as arguments to view function
item_data = {"field": 12}
response = client.post(
"/test/",
data=json.dumps(item_data),
content_type="application/json",
)
assert response.status_code == 200
assert response.json == {
"document": {"db_field": 12},
}

@pytest.mark.parametrize("openapi_version", ["2.0", "3.0.2"])
def test_blueprint_dict_response_schema(self, app, schemas, openapi_version):
"""Check alt_response passes response transparently"""
app.config["OPENAPI_VERSION"] = openapi_version
api = Api(app)
blp = Blueprint("test", "test", url_prefix="/test")
client = app.test_client()

@blp.route("/")
@blp.response(200, schema=schemas.DictSchema)
def func():
return {"item_id": 12}

api.register_blueprint(blp)

resp = client.get("/test/")
assert resp.json == {"item_id": 12}

@pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2"))
def test_blueprint_arguments_files_multipart(self, app, schemas, openapi_version):
app.config["OPENAPI_VERSION"] = openapi_version
Expand Down