Skip to content

Commit

Permalink
Merge pull request #847 from marshmallow-code/typing-improvements
Browse files Browse the repository at this point in the history
Make Parser a generic type over Request
  • Loading branch information
sirosen committed Jun 14, 2023
2 parents b038456 + ed376e7 commit 5f8d589
Show file tree
Hide file tree
Showing 12 changed files with 138 additions and 89 deletions.
12 changes: 11 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@ Changelog
8.3.0 (unreleased)
******************

Features:

* ``webargs.Parser`` now inherits from ``typing.Generic`` and is parametrizable
over the type of the request object. Various framework-specific parsers are
parametrized over their relevant request object classes.

Other changes:

- Test against Python 3.11 (:pr:`787`).
* Type annotations have been improved to allow ``Mapping`` for dict-like
schemas where previously ``dict`` was used. This makes the type covariant
rather than invariant (:issue:`836`).

* Test against Python 3.11 (:pr:`787`).

8.2.0 (2022-07-11)
******************
Expand Down
2 changes: 1 addition & 1 deletion src/webargs/aiohttpparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _find_exceptions() -> None:
del _find_exceptions


class AIOHTTPParser(AsyncParser):
class AIOHTTPParser(AsyncParser[web.Request]):
"""aiohttp request argument parser."""

DEFAULT_UNKNOWN_BY_LOCATION: dict[str, str | None] = {
Expand Down
2 changes: 1 addition & 1 deletion src/webargs/asyncparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from webargs import core


class AsyncParser(core.Parser):
class AsyncParser(core.Parser[core.Request]):
"""Asynchronous variant of `webargs.core.Parser`.
The ``parse`` method is redefined to be ``async``.
Expand Down
2 changes: 1 addition & 1 deletion src/webargs/bottleparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def index(args):
from webargs import core


class BottleParser(core.Parser):
class BottleParser(core.Parser[bottle.Request]):
"""Bottle.py request argument parser."""

def _handle_invalid_json_error(self, error, req, *args, **kwargs):
Expand Down
13 changes: 9 additions & 4 deletions src/webargs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
Request = typing.TypeVar("Request")
ArgMap = typing.Union[
ma.Schema,
typing.Dict[str, typing.Union[ma.fields.Field, typing.Type[ma.fields.Field]]],
typing.Mapping[str, typing.Union[ma.fields.Field, typing.Type[ma.fields.Field]]],
typing.Callable[[Request], ma.Schema],
]

ValidateArg = typing.Union[None, typing.Callable, typing.Iterable[typing.Callable]]
CallableList = typing.List[typing.Callable]
ErrorHandler = typing.Callable[..., typing.NoReturn]
Expand Down Expand Up @@ -115,7 +116,7 @@ def _ensure_list_of_callables(obj: typing.Any) -> CallableList:
return validators


class Parser:
class Parser(typing.Generic[Request]):
"""Base parser class that provides high-level implementation for parsing
a request.
Expand Down Expand Up @@ -302,7 +303,9 @@ def _get_schema(self, argmap: ArgMap, req: Request) -> ma.Schema:
schema = argmap()
elif callable(argmap):
schema = argmap(req)
elif isinstance(argmap, dict):
elif isinstance(argmap, typing.Mapping):
if not isinstance(argmap, dict):
argmap = dict(argmap)
schema = self.schema_class.from_dict(argmap)()
else:
raise TypeError(f"argmap was of unexpected type {type(argmap)}")
Expand Down Expand Up @@ -541,7 +544,9 @@ def greet(args):
request_obj = req
# Optimization: If argmap is passed as a dictionary, we only need
# to generate a Schema once
if isinstance(argmap, dict):
if isinstance(argmap, typing.Mapping):
if not isinstance(argmap, dict):
argmap = dict(argmap)
argmap = self.schema_class.from_dict(argmap)()

def decorator(func: typing.Callable) -> typing.Callable:
Expand Down
16 changes: 9 additions & 7 deletions src/webargs/djangoparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ class MyView(View):
def get(self, args, request):
return HttpResponse('Hello ' + args['name'])
"""
import django

from webargs import core


def is_json_request(req):
return core.is_json(req.content_type)


class DjangoParser(core.Parser):
class DjangoParser(core.Parser[django.http.HttpRequest]):
"""Django request argument parser.
.. warning::
Expand All @@ -35,7 +37,7 @@ class DjangoParser(core.Parser):
the parser and returning the appropriate `HTTPResponse`.
"""

def _raw_load_json(self, req):
def _raw_load_json(self, req: django.http.HttpRequest):
"""Read a json payload from the request for the core parser's load_json
Checks the input mimetype and may return 'missing' if the mimetype is
Expand All @@ -45,25 +47,25 @@ def _raw_load_json(self, req):

return core.parse_json(req.body)

def load_querystring(self, req, schema):
def load_querystring(self, req: django.http.HttpRequest, schema):
"""Return query params from the request as a MultiDictProxy."""
return self._makeproxy(req.GET, schema)

def load_form(self, req, schema):
def load_form(self, req: django.http.HttpRequest, schema):
"""Return form values from the request as a MultiDictProxy."""
return self._makeproxy(req.POST, schema)

def load_cookies(self, req, schema):
def load_cookies(self, req: django.http.HttpRequest, schema):
"""Return cookies from the request."""
return req.COOKIES

def load_headers(self, req, schema):
def load_headers(self, req: django.http.HttpRequest, schema):
"""Return headers from the request."""
# Django's HttpRequest.headers is a case-insensitive dict type, but it
# isn't a multidict, so this is not proxied
return req.headers

def load_files(self, req, schema):
def load_files(self, req: django.http.HttpRequest, schema):
"""Return files from the request as a MultiDictProxy."""
return self._makeproxy(req.FILES, schema)

Expand Down
29 changes: 15 additions & 14 deletions src/webargs/falconparser.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Falcon request argument parsing module.
"""
import falcon
from falcon.util.uri import parse_query_string

import marshmallow as ma
from falcon.util.uri import parse_query_string

from webargs import core

Expand All @@ -25,13 +24,13 @@ def _find_exceptions():
del _find_exceptions


def is_json_request(req):
def is_json_request(req: falcon.Request):
content_type = req.get_header("Content-Type")
return content_type and core.is_json(content_type)


# NOTE: Adapted from falcon.request.Request._parse_form_urlencoded
def parse_form_body(req):
def parse_form_body(req: falcon.Request):
if (
req.content_type is not None
and "application/x-www-form-urlencoded" in req.content_type
Expand Down Expand Up @@ -69,7 +68,7 @@ def to_dict(self, *args, **kwargs):
return ret


class FalconParser(core.Parser):
class FalconParser(core.Parser[falcon.Request]):
"""Falcon request argument parser.
Defaults to using the `media` location. See :py:meth:`~FalconParser.load_media` for
Expand All @@ -94,11 +93,11 @@ class FalconParser(core.Parser):
# values should be wrapped in lists due to the type of the destination
# field

def load_querystring(self, req, schema):
def load_querystring(self, req: falcon.Request, schema):
"""Return query params from the request as a MultiDictProxy."""
return self._makeproxy(req.params, schema)

def load_form(self, req, schema):
def load_form(self, req: falcon.Request, schema):
"""Return form values from the request as a MultiDictProxy
.. note::
Expand All @@ -110,7 +109,7 @@ def load_form(self, req, schema):
return form
return self._makeproxy(form, schema)

def load_media(self, req, schema):
def load_media(self, req: falcon.Request, schema):
"""Return data unpacked and parsed by one of Falcon's media handlers.
By default, Falcon only handles JSON payloads.
Expand All @@ -129,7 +128,7 @@ def load_media(self, req, schema):
return core.missing
return req.media

def _raw_load_json(self, req):
def _raw_load_json(self, req: falcon.Request):
"""Return a json payload from the request for the core parser's load_json
Checks the input mimetype and may return 'missing' if the mimetype is
Expand All @@ -141,12 +140,12 @@ def _raw_load_json(self, req):
return core.parse_json(body)
return core.missing

def load_headers(self, req, schema):
def load_headers(self, req: falcon.Request, schema):
"""Return headers from the request."""
# Falcon only exposes headers as a dict (not multidict)
return req.headers

def load_cookies(self, req, schema):
def load_cookies(self, req: falcon.Request, schema):
"""Return cookies from the request."""
# Cookies are expressed in Falcon as a dict, but the possibility of
# multiple values for a cookie is preserved internally -- if desired in
Expand All @@ -163,19 +162,21 @@ def get_request_from_view_args(self, view, args, kwargs):
raise TypeError("Argument is not a falcon.Request")
return req

def load_files(self, req, schema):
def load_files(self, req: falcon.Request, schema):
raise NotImplementedError(
f"Parsing files not yet supported by {self.__class__.__name__}"
)

def handle_error(self, error, req, schema, *, error_status_code, error_headers):
def handle_error(
self, error, req: falcon.Request, schema, *, error_status_code, error_headers
):
"""Handles errors during parsing."""
status = status_map.get(error_status_code or self.DEFAULT_VALIDATION_STATUS)
if status is None:
raise LookupError(f"Status code {error_status_code} not supported")
raise HTTPError(status, errors=error.messages, headers=error_headers)

def _handle_invalid_json_error(self, error, req, *args, **kwargs):
def _handle_invalid_json_error(self, error, req: falcon.Request, *args, **kwargs):
status = status_map[400]
messages = {"json": ["Invalid JSON body."]}
raise HTTPError(status, errors=messages)
Expand Down
30 changes: 16 additions & 14 deletions src/webargs/flaskparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ def user_detail(args, uid):
)
"""
from __future__ import annotations

from typing import NoReturn

import flask
from werkzeug.exceptions import HTTPException

import marshmallow as ma
from werkzeug.exceptions import HTTPException

from webargs import core

Expand All @@ -45,11 +45,11 @@ def abort(http_status_code, exc=None, **kwargs) -> NoReturn:
raise err


def is_json_request(req):
def is_json_request(req: flask.Request):
return core.is_json(req.mimetype)


class FlaskParser(core.Parser):
class FlaskParser(core.Parser[flask.Request]):
"""Flask request argument parser."""

DEFAULT_UNKNOWN_BY_LOCATION: dict[str, str | None] = {
Expand All @@ -63,7 +63,7 @@ class FlaskParser(core.Parser):
**core.Parser.__location_map__,
)

def _raw_load_json(self, req):
def _raw_load_json(self, req: flask.Request):
"""Return a json payload from the request for the core parser's load_json
Checks the input mimetype and may return 'missing' if the mimetype is
Expand All @@ -73,34 +73,36 @@ def _raw_load_json(self, req):

return core.parse_json(req.get_data(cache=True))

def _handle_invalid_json_error(self, error, req, *args, **kwargs):
def _handle_invalid_json_error(self, error, req: flask.Request, *args, **kwargs):
abort(400, exc=error, messages={"json": ["Invalid JSON body."]})

def load_view_args(self, req, schema):
def load_view_args(self, req: flask.Request, schema):
"""Return the request's ``view_args`` or ``missing`` if there are none."""
return req.view_args or core.missing

def load_querystring(self, req, schema):
def load_querystring(self, req: flask.Request, schema):
"""Return query params from the request as a MultiDictProxy."""
return self._makeproxy(req.args, schema)

def load_form(self, req, schema):
def load_form(self, req: flask.Request, schema):
"""Return form values from the request as a MultiDictProxy."""
return self._makeproxy(req.form, schema)

def load_headers(self, req, schema):
def load_headers(self, req: flask.Request, schema):
"""Return headers from the request as a MultiDictProxy."""
return self._makeproxy(req.headers, schema)

def load_cookies(self, req, schema):
def load_cookies(self, req: flask.Request, schema):
"""Return cookies from the request."""
return req.cookies

def load_files(self, req, schema):
def load_files(self, req: flask.Request, schema):
"""Return files from the request as a MultiDictProxy."""
return self._makeproxy(req.files, schema)

def handle_error(self, error, req, schema, *, error_status_code, error_headers):
def handle_error(
self, error, req: flask.Request, schema, *, error_status_code, error_headers
):
"""Handles errors during parsing. Aborts the current HTTP request and
responds with a 422 error.
"""
Expand All @@ -113,7 +115,7 @@ def handle_error(self, error, req, schema, *, error_status_code, error_headers):
headers=error_headers,
)

def get_default_request(self):
def get_default_request(self) -> flask.Request:
"""Override to use Flask's thread-local request object by default"""
return flask.request

Expand Down

0 comments on commit 5f8d589

Please sign in to comment.