diff --git a/starlette/authentication.py b/starlette/authentication.py index b4882070d5..9384f1045a 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -6,7 +6,7 @@ from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection, Request from starlette.responses import RedirectResponse, Response -from starlette.websockets import WebSocket +from starlette.websockets import WebSocket, WebsocketDenialResponse def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool: @@ -46,9 +46,17 @@ async def websocket_wrapper( assert isinstance(websocket, WebSocket) if not has_required_scope(websocket, scopes_list): - await websocket.close() - else: - await func(*args, **kwargs) + if redirect is not None: + response = WebsocketDenialResponse( + RedirectResponse( + url=websocket.url_for(redirect), status_code=303 + ) + ) + await response.send(websocket) + else: + raise HTTPException(status_code=status_code) + + await func(*args, **kwargs) return websocket_wrapper @@ -66,7 +74,9 @@ async def async_wrapper( return RedirectResponse( url=request.url_for(redirect), status_code=303 ) - raise HTTPException(status_code=status_code) + else: + raise HTTPException(status_code=status_code) + return await func(*args, **kwargs) return async_wrapper @@ -83,7 +93,9 @@ def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: return RedirectResponse( url=request.url_for(redirect), status_code=303 ) - raise HTTPException(status_code=status_code) + else: + raise HTTPException(status_code=status_code) + return func(*args, **kwargs) return sync_wrapper diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py index 6e2d2dade3..72c76f4200 100644 --- a/starlette/middleware/authentication.py +++ b/starlette/middleware/authentication.py @@ -9,6 +9,7 @@ from starlette.requests import HTTPConnection from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.websockets import WebsocketDenialResponse class AuthenticationMiddleware: @@ -37,9 +38,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: except AuthenticationError as exc: response = self.on_error(conn, exc) if scope["type"] == "websocket": - await send({"type": "websocket.close", "code": 1000}) - else: - await response(scope, receive, send) + response = WebsocketDenialResponse(response) + await response(scope, receive, send) return if auth_result is None: