Skip to content

Commit

Permalink
Merge pull request #376 from ThiefMaster/customizable-schema
Browse files Browse the repository at this point in the history
Make schema class configurable
  • Loading branch information
sloria committed Mar 16, 2019
2 parents 86564f1 + 06e7b5c commit c60472b
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 8 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Changelog
---------

5.2.0 (unreleased)
******************

Features:

* Make the schema class used when generating a schema from a
dict overridable (:issue:`375`)

5.1.3 (2019-03-11)
******************

Expand Down
2 changes: 1 addition & 1 deletion src/webargs/asyncparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def use_args(
# Optimization: If argmap is passed as a dictionary, we only need
# to generate a Schema once
if isinstance(argmap, Mapping):
argmap = core.dict2schema(argmap)()
argmap = core.dict2schema(argmap, self.schema_class)()

def decorator(func: typing.Callable) -> typing.Callable:
req_ = request_obj
Expand Down
14 changes: 9 additions & 5 deletions src/webargs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _callable_or_raise(obj):
return obj


def dict2schema(dct):
def dict2schema(dct, schema_class=ma.Schema):
"""Generate a `marshmallow.Schema` class given a dictionary of
`Fields <marshmallow.fields.Field>`.
"""
Expand All @@ -64,7 +64,7 @@ class Meta(object):
register = False

attrs["Meta"] = Meta
return type(str(""), (ma.Schema,), attrs)
return type(str(""), (schema_class,), attrs)


def is_multiple(field):
Expand Down Expand Up @@ -162,7 +162,10 @@ class Parser(object):
:param callable error_handler: Custom error handler function.
"""

#: Default locations to check for data
DEFAULT_LOCATIONS = ("querystring", "form", "json")
#: The marshmallow Schema class to use when creating new schemas
DEFAULT_SCHEMA_CLASS = ma.Schema
#: Default status code to return for validation errors
DEFAULT_VALIDATION_STATUS = DEFAULT_VALIDATION_STATUS
#: Default error message for validation errors
Expand All @@ -179,9 +182,10 @@ class Parser(object):
"files": "parse_files",
}

def __init__(self, locations=None, error_handler=None):
def __init__(self, locations=None, error_handler=None, schema_class=None):
self.locations = locations or self.DEFAULT_LOCATIONS
self.error_callback = _callable_or_raise(error_handler)
self.schema_class = schema_class or self.DEFAULT_SCHEMA_CLASS
#: A short-lived cache to store results from processing request bodies.
self._cache = {}

Expand Down Expand Up @@ -309,7 +313,7 @@ def _get_schema(self, argmap, req):
elif callable(argmap):
schema = argmap(req)
else:
schema = dict2schema(argmap)()
schema = dict2schema(argmap, self.schema_class)()
if MARSHMALLOW_VERSION_INFO[0] < 3 and not schema.strict:
warnings.warn(
"It is highly recommended that you set strict=True on your schema "
Expand Down Expand Up @@ -439,7 +443,7 @@ def greet(args):
# Optimization: If argmap is passed as a dictionary, we only need
# to generate a Schema once
if isinstance(argmap, Mapping):
argmap = dict2schema(argmap)()
argmap = dict2schema(argmap, self.schema_class)()

def decorator(func):
req_ = request_obj
Expand Down
4 changes: 4 additions & 0 deletions src/webargs/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
class Nested(ma.fields.Nested):
"""Same as `marshmallow.fields.Nested`, except can be passed a dictionary as
the first argument, which will be converted to a `marshmallow.Schema`.
Note: The schema class here will always be `marshmallow.Schema`, regardless
of whether a custom schema class is set on the parser. Pass an explicit schema
class if necessary.
"""

def __init__(self, nested, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion src/webargs/pyramidparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def use_args(
# Optimization: If argmap is passed as a dictionary, we only need
# to generate a Schema once
if isinstance(argmap, collections.Mapping):
argmap = core.dict2schema(argmap)()
argmap = core.dict2schema(argmap, self.schema_class)()

def decorator(func):
@functools.wraps(func)
Expand Down
35 changes: 34 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import datetime

import pytest
from marshmallow import Schema, post_load, class_registry, validates_schema
from marshmallow import Schema, post_load, pre_load, class_registry, validates_schema
from werkzeug.datastructures import MultiDict as WerkMultiDict
from django.utils.datastructures import MultiValueDict as DjMultiDict
from bottle import MultiDict as BotMultiDict
Expand Down Expand Up @@ -1052,3 +1052,36 @@ def test_parse_with_error_status_code_and_headers(web_request):
error = excinfo.value
assert error.status_code == 418
assert error.headers == {"X-Foo": "bar"}


@mock.patch("webargs.core.Parser.parse_json")
def test_custom_schema_class(parse_json, web_request):
class CustomSchema(Schema):
@pre_load
def pre_load(self, data):
data["value"] += " world"
return data

parse_json.return_value = "hello"
argmap = {"value": fields.Str()}
p = Parser(schema_class=CustomSchema)
ret = p.parse(argmap, web_request)
assert ret == {"value": "hello world"}


@mock.patch("webargs.core.Parser.parse_json")
def test_custom_default_schema_class(parse_json, web_request):
class CustomSchema(Schema):
@pre_load
def pre_load(self, data):
data["value"] += " world"
return data

class CustomParser(Parser):
DEFAULT_SCHEMA_CLASS = CustomSchema

parse_json.return_value = "hello"
argmap = {"value": fields.Str()}
p = CustomParser()
ret = p.parse(argmap, web_request)
assert ret == {"value": "hello world"}

0 comments on commit c60472b

Please sign in to comment.