Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve http decorator typing #75541

Merged
merged 1 commit into from
Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion homeassistant/components/auth/login_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ async def get(self, request):

@RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
@log_invalid_auth
async def post(self, request, flow_id, data):
async def post(self, request, data, flow_id):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it have been typed here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #75586 to fully type the auth integration.

"""Handle progressing a login flow request."""
client_id = data.pop("client_id")

Expand Down
18 changes: 11 additions & 7 deletions homeassistant/components/http/ban.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Coroutine
from contextlib import suppress
from datetime import datetime
from http import HTTPStatus
from ipaddress import IPv4Address, IPv6Address, ip_address
import logging
from socket import gethostbyaddr, herror
from typing import Any, Final
from typing import Any, Final, TypeVar

from aiohttp.web import Application, Request, StreamResponse, middleware
from aiohttp.web import Application, Request, Response, StreamResponse, middleware
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
from typing_extensions import Concatenate, ParamSpec
import voluptuous as vol

from homeassistant.components import persistent_notification
Expand All @@ -24,6 +25,9 @@

from .view import HomeAssistantView

_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")

_LOGGER: Final = logging.getLogger(__name__)

KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
Expand Down Expand Up @@ -82,13 +86,13 @@ async def ban_middleware(


def log_invalid_auth(
func: Callable[..., Awaitable[StreamResponse]]
) -> Callable[..., Awaitable[StreamResponse]]:
func: Callable[Concatenate[_HassViewT, Request, _P], Awaitable[Response]]
) -> Callable[Concatenate[_HassViewT, Request, _P], Coroutine[Any, Any, Response]]:
"""Decorate function to handle invalid auth or failed login attempts."""

async def handle_req(
view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any
) -> StreamResponse:
view: _HassViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
) -> Response:
"""Try to log failed login attempts if response status >= BAD_REQUEST."""
resp = await func(view, request, *args, **kwargs)
if resp.status >= HTTPStatus.BAD_REQUEST:
Expand Down
33 changes: 22 additions & 11 deletions homeassistant/components/http/data_validator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
"""Decorator for view methods to help with data validation."""
from __future__ import annotations

from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Coroutine
from functools import wraps
from http import HTTPStatus
import logging
from typing import Any
from typing import Any, TypeVar

from aiohttp import web
from typing_extensions import Concatenate, ParamSpec
import voluptuous as vol

from .view import HomeAssistantView

_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")

_LOGGER = logging.getLogger(__name__)


Expand All @@ -33,33 +37,40 @@ def __init__(self, schema: vol.Schema, allow_empty: bool = False) -> None:
self._allow_empty = allow_empty

def __call__(
self, method: Callable[..., Awaitable[web.StreamResponse]]
) -> Callable:
self,
method: Callable[
Concatenate[_HassViewT, web.Request, dict[str, Any], _P],
Awaitable[web.Response],
],
) -> Callable[
Concatenate[_HassViewT, web.Request, _P],
Coroutine[Any, Any, web.Response],
]:
"""Decorate a function."""

@wraps(method)
async def wrapper(
view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any
) -> web.StreamResponse:
view: _HassViewT, request: web.Request, *args: _P.args, **kwargs: _P.kwargs
) -> web.Response:
"""Wrap a request handler with data validation."""
data = None
raw_data = None
try:
data = await request.json()
raw_data = await request.json()
except ValueError:
if not self._allow_empty or (await request.content.read()) != b"":
_LOGGER.error("Invalid JSON received")
return view.json_message("Invalid JSON.", HTTPStatus.BAD_REQUEST)
data = {}
raw_data = {}

try:
kwargs["data"] = self._schema(data)
data: dict[str, Any] = self._schema(raw_data)
except vol.Invalid as err:
_LOGGER.error("Data does not match schema: %s", err)
return view.json_message(
f"Message format incorrect: {err}", HTTPStatus.BAD_REQUEST
)

result = await method(view, request, *args, **kwargs)
result = await method(view, request, data, *args, **kwargs)
return result

return wrapper
4 changes: 2 additions & 2 deletions homeassistant/components/repairs/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response

result = self._prepare_result_json(result)

return self.json(result) # pylint: disable=arguments-differ
return self.json(result)


class RepairsFlowResourceView(FlowManagerResourceView):
Expand All @@ -135,4 +135,4 @@ async def post(self, request: web.Request, flow_id: str) -> web.Response:
raise Unauthorized(permission=POLICY_EDIT)

# pylint: disable=no-value-for-parameter
return await super().post(request, flow_id) # type: ignore[no-any-return]
return await super().post(request, flow_id)
2 changes: 1 addition & 1 deletion homeassistant/helpers/data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def get(self, request: web.Request, flow_id: str) -> web.Response:

@RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post(
self, request: web.Request, flow_id: str, data: dict[str, Any]
self, request: web.Request, data: dict[str, Any], flow_id: str
) -> web.Response:
"""Handle a POST request."""
try:
Expand Down