Skip to content

Commit

Permalink
Update pyramidparser for webargs v6
Browse files Browse the repository at this point in the history
Convert parse_* to load_*
This conversion is fairly straightforward, replacing `get_value` calls
with MultiDictProxy instantiations.
  • Loading branch information
sirosen committed Sep 14, 2019
1 parent a9ba898 commit f005c40
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 53 deletions.
73 changes: 32 additions & 41 deletions src/webargs/pyramidparser.py
Expand Up @@ -34,56 +34,47 @@ def hello_world(request, args):
from webargs import core
from webargs.core import json
from webargs.compat import text_type
from webargs.multidictproxy import MultiDictProxy


class PyramidParser(core.Parser):
"""Pyramid request argument parser."""

__location_map__ = dict(
matchdict="parse_matchdict",
path="parse_matchdict",
matchdict="load_matchdict",
path="load_matchdict",
**core.Parser.__location_map__
)

def parse_querystring(self, req, name, field):
"""Pull a querystring value from the request."""
return core.get_value(req.GET, name, field)

def parse_form(self, req, name, field):
"""Pull a form value from the request."""
return core.get_value(req.POST, name, field)

def parse_json(self, req, name, field):
"""Pull a json value from the request."""
json_data = self._cache.get("json")
if json_data is None:
try:
self._cache["json"] = json_data = core.parse_json(req.body, req.charset)
except json.JSONDecodeError as e:
if e.doc == "":
return core.missing
else:
return self.handle_invalid_json_error(e, req)
if json_data is None:
return core.missing
return core.get_value(json_data, name, field, allow_many_nested=True)
def _raw_load_json(self, req):
"""Return a json payload from the request for the core parser's
load_json"""
return core.parse_json(req.body, req.charset)

def load_querystring(self, req, schema):
"""Return query params from the request as a MultiDictProxy."""
return MultiDictProxy(req.GET, schema)

def load_form(self, req, schema):
"""Return form values from the request as a MultiDictProxy."""
return MultiDictProxy(req.POST, schema)

def parse_cookies(self, req, name, field):
"""Pull the value from the cookiejar."""
return core.get_value(req.cookies, name, field)
def load_cookies(self, req, schema):
"""Return cookies from the request as a MultiDictProxy."""
return MultiDictProxy(req.cookies, schema)

def parse_headers(self, req, name, field):
"""Pull a value from the header data."""
return core.get_value(req.headers, name, field)
def load_headers(self, req, schema):
"""Return headers from the request as a MultiDictProxy."""
return MultiDictProxy(req.headers, schema)

def parse_files(self, req, name, field):
"""Pull a file from the request."""
def load_files(self, req, schema):
"""Return files from the request as a MultiDictProxy."""
files = ((k, v) for k, v in req.POST.items() if hasattr(v, "file"))
return core.get_value(MultiDict(files), name, field)
return MultiDictProxy(MultiDict(files), schema)

def parse_matchdict(self, req, name, field):
"""Pull a value from the request's `matchdict`."""
return core.get_value(req.matchdict, name, field)
def load_matchdict(self, req, schema):
"""Return the request's ``matchdict`` as a MultiDictProxy."""
return MultiDictProxy(req.matchdict, schema)

def handle_error(self, error, req, schema, error_status_code, error_headers):
"""Handles errors during parsing. Aborts the current HTTP request and
Expand All @@ -100,7 +91,7 @@ def handle_error(self, error, req, schema, error_status_code, error_headers):
response.body = body.encode("utf-8") if isinstance(body, text_type) else body
raise response

def handle_invalid_json_error(self, error, req, *args, **kwargs):
def _handle_invalid_json_error(self, error, req, *args, **kwargs):
messages = {"json": ["Invalid JSON body."]}
response = exception_response(
400, detail=text_type(messages), content_type="application/json"
Expand All @@ -113,7 +104,7 @@ def use_args(
self,
argmap,
req=None,
locations=core.Parser.DEFAULT_LOCATIONS,
location=core.Parser.DEFAULT_LOCATION,
as_kwargs=False,
validate=None,
error_status_code=None,
Expand All @@ -127,7 +118,7 @@ def use_args(
of argname -> `marshmallow.fields.Field` pairs, or a callable
which accepts a request and returns a `marshmallow.Schema`.
:param req: The request object to parse. Pulled off of the view by default.
:param tuple locations: Where on the request to search for values.
:param str location: Where on the request to load values.
:param bool as_kwargs: Whether to insert arguments as keyword arguments.
:param callable validate: Validation function that receives the dictionary
of parsed arguments. If the function returns ``False``, the parser
Expand All @@ -137,7 +128,7 @@ def use_args(
:param dict error_headers: Headers passed to error handler functions when a
a `ValidationError` is raised.
"""
locations = locations or self.locations
location = location or self.location
# Optimization: If argmap is passed as a dictionary, we only need
# to generate a Schema once
if isinstance(argmap, collections.Mapping):
Expand All @@ -155,7 +146,7 @@ def wrapper(obj, *args, **kwargs):
parsed_args = self.parse(
argmap,
req=request,
locations=locations,
location=location,
validate=validate,
error_status_code=error_status_code,
error_headers=error_headers,
Expand Down
63 changes: 51 additions & 12 deletions tests/apps/pyramid_app.py
Expand Up @@ -19,8 +19,22 @@ class HelloSchema(ma.Schema):
strict_kwargs = {"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {}
hello_many_schema = HelloSchema(many=True, **strict_kwargs)

# variant which ignores unknown fields
exclude_kwargs = (
{"strict": True} if MARSHMALLOW_VERSION_INFO[0] < 3 else {"unknown": ma.EXCLUDE}
)
hello_exclude_schema = HelloSchema(**exclude_kwargs)


def echo(request):
return parser.parse(hello_args, request, location="query")


def echo_form(request):
return parser.parse(hello_args, request, location="form")


def echo_json(request):
try:
return parser.parse(hello_args, request)
except json.JSONDecodeError:
Expand All @@ -30,39 +44,59 @@ def echo(request):
raise error


def echo_json_ignore_extra_data(request):
try:
return parser.parse(hello_exclude_schema, request)
except json.JSONDecodeError:
error = HTTPBadRequest()
error.body = json.dumps(["Invalid JSON."]).encode("utf-8")
error.content_type = "application/json"
raise error


def echo_query(request):
return parser.parse(hello_args, request, locations=("query",))
return parser.parse(hello_args, request, location="query")


@use_args(hello_args)
@use_args(hello_args, location="query")
def echo_use_args(request, args):
return args


@use_args({"value": fields.Int()}, validate=lambda args: args["value"] > 42)
@use_args(
{"value": fields.Int()}, validate=lambda args: args["value"] > 42, location="form"
)
def echo_use_args_validated(request, args):
return args


@use_kwargs(hello_args)
@use_kwargs(hello_args, location="query")
def echo_use_kwargs(request, name):
return {"name": name}


def echo_multi(request):
return parser.parse(hello_multiple, request, location="query")


def echo_multi_form(request):
return parser.parse(hello_multiple, request, location="form")


def echo_multi_json(request):
return parser.parse(hello_multiple, request)


def echo_many_schema(request):
return parser.parse(hello_many_schema, request, locations=("json",))
return parser.parse(hello_many_schema, request)


@use_args({"value": fields.Int()})
@use_args({"value": fields.Int()}, location="query")
def echo_use_args_with_path_param(request, args):
return args


@use_kwargs({"value": fields.Int()})
@use_kwargs({"value": fields.Int()}, location="query")
def echo_use_kwargs_with_path_param(request, value):
return {"value": value}

Expand All @@ -76,16 +110,16 @@ def always_fail(value):


def echo_headers(request):
return parser.parse(hello_args, request, locations=("headers",))
return parser.parse(hello_exclude_schema, request, location="headers")


def echo_cookie(request):
return parser.parse(hello_args, request, locations=("cookies",))
return parser.parse(hello_args, request, location="cookies")


def echo_file(request):
args = {"myfile": fields.Field()}
result = parser.parse(args, request, locations=("files",))
result = parser.parse(args, request, location="files")
myfile = result["myfile"]
content = myfile.file.read().decode("utf8")
return {"myfile": content}
Expand All @@ -104,14 +138,14 @@ def echo_nested_many(request):


def echo_matchdict(request):
return parser.parse({"mymatch": fields.Int()}, request, locations=("matchdict",))
return parser.parse({"mymatch": fields.Int()}, request, location="matchdict")


class EchoCallable(object):
def __init__(self, request):
self.request = request

@use_args({"value": fields.Int()})
@use_args({"value": fields.Int()}, location="query")
def __call__(self, args):
return args

Expand All @@ -127,11 +161,16 @@ def create_app():
config = Configurator()

add_route(config, "/echo", echo)
add_route(config, "/echo_form", echo_form)
add_route(config, "/echo_json", echo_json)
add_route(config, "/echo_query", echo_query)
add_route(config, "/echo_ignoring_extra_data", echo_json_ignore_extra_data)
add_route(config, "/echo_use_args", echo_use_args)
add_route(config, "/echo_use_args_validated", echo_use_args_validated)
add_route(config, "/echo_use_kwargs", echo_use_kwargs)
add_route(config, "/echo_multi", echo_multi)
add_route(config, "/echo_multi_form", echo_multi_form)
add_route(config, "/echo_multi_json", echo_multi_json)
add_route(config, "/echo_many_schema", echo_many_schema)
add_route(
config, "/echo_use_args_with_path_param/{name}", echo_use_args_with_path_param
Expand Down

0 comments on commit f005c40

Please sign in to comment.