Skip to content

Commit

Permalink
Update AuthenticationMiddleware and @requires to use WebsocketDenialR…
Browse files Browse the repository at this point in the history
…esponse
  • Loading branch information
paulo-raca committed Feb 9, 2022
1 parent 89fe8f5 commit 76c36e3
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
24 changes: 18 additions & 6 deletions starlette/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
from starlette.websockets import WebSocket
from starlette.websockets import WebSocket, WebsocketDenialResponse


def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
Expand Down Expand Up @@ -46,9 +46,17 @@ async def websocket_wrapper(
assert isinstance(websocket, WebSocket)

if not has_required_scope(websocket, scopes_list):
await websocket.close()
else:
await func(*args, **kwargs)
if redirect is not None:
response = WebsocketDenialResponse(
RedirectResponse(
url=websocket.url_for(redirect), status_code=303
)
)
await response.send(websocket)
else:
raise HTTPException(status_code=status_code)

await func(*args, **kwargs)

return websocket_wrapper

Expand All @@ -66,7 +74,9 @@ async def async_wrapper(
return RedirectResponse(
url=request.url_for(redirect), status_code=303
)
raise HTTPException(status_code=status_code)
else:
raise HTTPException(status_code=status_code)

return await func(*args, **kwargs)

return async_wrapper
Expand All @@ -83,7 +93,9 @@ def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
return RedirectResponse(
url=request.url_for(redirect), status_code=303
)
raise HTTPException(status_code=status_code)
else:
raise HTTPException(status_code=status_code)

return func(*args, **kwargs)

return sync_wrapper
Expand Down
6 changes: 3 additions & 3 deletions starlette/middleware/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from starlette.requests import HTTPConnection
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebsocketDenialResponse


class AuthenticationMiddleware:
Expand Down Expand Up @@ -37,9 +38,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
except AuthenticationError as exc:
response = self.on_error(conn, exc)
if scope["type"] == "websocket":
await send({"type": "websocket.close", "code": 1000})
else:
await response(scope, receive, send)
response = WebsocketDenialResponse(response)
await response(scope, receive, send)
return

if auth_result is None:
Expand Down

0 comments on commit 76c36e3

Please sign in to comment.